Integrate new cache system for training #472

Merged: 20 commits, Feb 16, 2024
Changes from 6 commits
@@ -0,0 +1,18 @@
import torch
from transformers.modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
)


leaf_prepare_4d_causal_attention_mask = torch.fx._symbolic_trace._create_wrapped_func(
_prepare_4d_causal_attention_mask
)

leaf_prepare_4d_causal_attention_mask_for_sdpa = torch.fx._symbolic_trace._create_wrapped_func(
_prepare_4d_causal_attention_mask_for_sdpa
)
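The private `torch.fx._symbolic_trace._create_wrapped_func` helper used above registers each mask-preparation function as an FX "leaf", so symbolic tracing records a single call node instead of descending into the data-dependent mask logic. A minimal sketch of the same idea using the public `torch.fx.wrap` API (the `build_causal_mask` helper and `TinyAttention` module below are illustrative, not part of this PR):

import torch
import torch.fx


def build_causal_mask(length: int) -> torch.Tensor:
    # Data-dependent logic that symbolic tracing cannot execute on proxy tensors.
    return torch.tril(torch.ones(length, length))


# Register the helper as a leaf: tracing records the call instead of entering it.
torch.fx.wrap("build_causal_mask")


class TinyAttention(torch.nn.Module):
    def forward(self, scores: torch.Tensor) -> torch.Tensor:
        mask = build_causal_mask(scores.shape[-1])
        return scores * mask


traced = torch.fx.symbolic_trace(TinyAttention())
print(traced.graph)  # build_causal_mask appears as a single call_function node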
106 changes: 60 additions & 46 deletions optimum/neuron/trainers.py
@@ -23,14 +23,13 @@
import sys
import time
import warnings
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from accelerate import __version__ as accelerate_version
from packaging import version
from torch.utils.data import Dataset
from transformers import PreTrainedModel, Seq2SeqTrainer, Trainer, TrainingArguments
from transformers.debug_utils import DebugOption, DebugUnderflowOverflow
from transformers.integrations import hp_params
@@ -56,6 +55,7 @@
EvalLoopOutput,
EvalPrediction,
HPSearchBackend,
PredictionOutput,
TrainOutput,
denumpify_detensorize,
has_length,
@@ -74,7 +74,8 @@
is_torch_xla_available,
patch_within_function,
)
from .utils.cache_utils import get_neuron_cache_path, set_neuron_cache_path
from .utils.cache_utils import get_hf_hub_cache_repos
from .utils.hub_neuronx_cache import hub_neuronx_cache, patch_neuron_cc_wrapper, synchronize_hub_cache
from .utils.require_utils import requires_neuronx_distributed
from .utils.training_utils import (
TRANSFORMERS_MIN_VERSION_USE_ACCELERATE,
@@ -113,38 +114,11 @@
if KEEP_HF_HUB_PROGRESS_BARS is None:
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"

# Used for torch.distributed.
_ORIGINAL_NEURON_CACHE_PATH: Optional[Path] = None
_TMP_NEURON_CACHE_DIR: Optional[TemporaryDirectory] = None
_TMP_NEURON_CACHE_PATH: Optional[Path] = None
_TCP_STORE_ADDRESS = "127.0.0.1"
_TCP_STORE_PORT = 5000


if os.environ.get("TORCHELASTIC_RUN_ID"):
import torch_xla.distributed.xla_backend as xbn

if not isinstance(torch.distributed.group.WORLD, xbn.ProcessGroupXla):
_ORIGINAL_NEURON_CACHE_PATH = get_neuron_cache_path()

# _ORIGINAL_NEURON_CACHE_PATH is `None` when the `--no-cache` flag is set.
if _ORIGINAL_NEURON_CACHE_PATH is not None:
if is_precompilation():
# During precompilation, we make sure to set the cache path to the defined compile cache path by the
# user. If nothing is specified, it is set to the default compile cache used by the Neuron compiler:
# /var/tmp/neuron-compile-cache
set_neuron_cache_path(_ORIGINAL_NEURON_CACHE_PATH)
else:
if os.environ["RANK"] == "0":
_TMP_NEURON_CACHE_DIR = NeuronCacheCallback.create_temporary_neuron_cache(get_neuron_cache_path())
store = torch.distributed.TCPStore(_TCP_STORE_ADDRESS, _TCP_STORE_PORT, is_master=True)
store.set("tmp_neuron_cache_path", _TMP_NEURON_CACHE_DIR.name)
_TMP_NEURON_CACHE_PATH = Path(_TMP_NEURON_CACHE_DIR.name)
else:
store = torch.distributed.TCPStore(_TCP_STORE_ADDRESS, _TCP_STORE_PORT, is_master=False)
_TMP_NEURON_CACHE_PATH = Path(store.get("tmp_neuron_cache_path").decode("utf-8"))
set_neuron_cache_path(_TMP_NEURON_CACHE_PATH)

torch.distributed.init_process_group(backend="xla")
if not isinstance(torch.distributed.group.WORLD, xbn.ProcessGroupXla):
raise AssertionError("Failed to initialize torch.distributed process group using XLA backend.")
@@ -194,19 +168,6 @@ def __init__(self, *args, **kwargs):
if self.args.local_rank <= 0:
logger.setLevel(logging.INFO)

push = self.args.local_rank <= 0 and not is_precompilation() and not self.args.skip_cache_push
fetch = self.args.local_rank <= 0 or self.args.mp_plugin.should_parallelize

callback = NeuronCacheCallback(
tmp_neuron_cache=_TMP_NEURON_CACHE_PATH,
original_neuron_cache_path=_ORIGINAL_NEURON_CACHE_PATH,
fetch=fetch,
push=push,
wait_for_everyone_on_fetch=True,
wait_for_everyone_on_push=True,
)
self.add_callback(callback)

# Make the model Neuron-compatible for generation.
patch_generation_mixin_to_neuron_generation_mixin(self.model)

@@ -500,10 +461,16 @@ def _save_xla(self, output_dir: Optional[str] = None):

def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
if not os.environ.get("NEURON_PARALLEL_COMPILE"): # Avoid unnecessary model saving during precompilation
if output_dir is None:
output_dir = self.args.output_dir
with patch_neuron_cc_wrapper():
with hub_neuronx_cache(cache_repo_id=get_hf_hub_cache_repos()[0]):
if output_dir is None:
output_dir = self.args.output_dir

self._save_xla(output_dir)
self._save_xla(output_dir)

if xm.get_ordinal() == 0:
synchronize_hub_cache(get_hf_hub_cache_repos()[0])
xm.rendezvous("Hub cache synchronization done")

# Push to the Hub when `save_model` is called by the user.
if self.args.push_to_hub and not _internal_call:
@@ -1315,6 +1282,53 @@ def evaluation_loop(

return EvalLoopOutput(predictions=all_preds, label_ids=all_labels, metrics=metrics, num_samples=num_samples)

def train(
self,
resume_from_checkpoint: Optional[Union[str, bool]] = None,
trial=None, # No type-annotation for this one because it is related to the optuna package.
ignore_keys_for_eval: Optional[List[str]] = None,
**kwargs,
):
with patch_neuron_cc_wrapper():
with hub_neuronx_cache(cache_repo_id=get_hf_hub_cache_repos()[0]):
result = super().train(
resume_from_checkpoint=resume_from_checkpoint,
trial=trial,
ignore_keys_for_eval=ignore_keys_for_eval,
**kwargs,
)
if xm.get_ordinal() == 0:
synchronize_hub_cache(get_hf_hub_cache_repos()[0])
xm.rendezvous("Hub cache synchronization done")
return result

def evaluate(
self,
eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
ignore_keys: Optional[List[str]] = None,
metric_key_prefix: str = "eval",
) -> Dict[str, float]:
with patch_neuron_cc_wrapper():
with hub_neuronx_cache(cache_repo_id=get_hf_hub_cache_repos()[0]):
result = super().evaluate(
eval_dataset=eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
)
if xm.get_ordinal() == 0:
synchronize_hub_cache(get_hf_hub_cache_repos()[0])
xm.rendezvous("Hub cache synchronization done")
return result

def predict(
self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None, metric_key_prefix: str = "test"
) -> PredictionOutput:
with patch_neuron_cc_wrapper():
with hub_neuronx_cache(cache_repo_id=get_hf_hub_cache_repos()[0]):
result = super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
if xm.get_ordinal() == 0:
synchronize_hub_cache(get_hf_hub_cache_repos()[0])
xm.rendezvous("Hub cache synchronization done")
return result


class NeuronTrainer(AugmentTrainerForNeuronMixin, Trainer):
"""
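The three overrides above (`train`, `evaluate`, `predict`) follow the same pattern: run the parent `Trainer` method inside `patch_neuron_cc_wrapper` and `hub_neuronx_cache`, then let the rank-0 worker push newly compiled artifacts to the Hub cache repo while the other workers wait at the rendezvous. A condensed sketch of that pattern as a reusable context manager (the `with_hub_cache` name is hypothetical, not part of this diff):

from contextlib import contextmanager

import torch_xla.core.xla_model as xm

from optimum.neuron.utils.cache_utils import get_hf_hub_cache_repos
from optimum.neuron.utils.hub_neuronx_cache import (
    hub_neuronx_cache,
    patch_neuron_cc_wrapper,
    synchronize_hub_cache,
)


@contextmanager
def with_hub_cache():
    """Compile through the Hub cache proxy, then synchronize it from rank 0."""
    repo_id = get_hf_hub_cache_repos()[0]
    with patch_neuron_cc_wrapper():
        with hub_neuronx_cache(cache_repo_id=repo_id):
            yield
    # Only the first worker pushes newly compiled artifacts; everyone waits for it.
    if xm.get_ordinal() == 0:
        synchronize_hub_cache(repo_id)
    xm.rendezvous("Hub cache synchronization done")

Each override then reduces to wrapping the corresponding `super()` call, e.g. `with with_hub_cache(): return super().train(...)`.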
47 changes: 44 additions & 3 deletions optimum/neuron/utils/hub_neuronx_cache.py
@@ -16,6 +16,7 @@
import json
import logging
import os
import shutil
from contextlib import contextmanager
from pathlib import Path
from tempfile import TemporaryDirectory
@@ -27,7 +28,7 @@
from ..version import __version__
from .import_utils import is_neuronx_available
from .patching import patch_everywhere
from .require_utils import requires_torch_neuronx
from .require_utils import requires_torch_neuronx, requires_torch_xla


if is_neuronx_available():
@@ -235,18 +236,20 @@ def hash(self):

@requires_torch_neuronx
@contextmanager
def hub_neuronx_cache(entry: Optional[ModelCacheEntry] = None):
def hub_neuronx_cache(entry: Optional[ModelCacheEntry] = None, cache_repo_id: Optional[str] = None):
"""A context manager to activate the Hugging Face Hub proxy compiler cache.

Args:
entry (`Optional[ModelCacheEntry]`, defaults to `None`):
An optional dataclass containing metadata associated with the model corresponding
to the cache session. Will create a dedicated entry in the cache registry.
cache_repo_id (`Optional[str]`, defaults to `None`):
The id of the cache repo to use to fetch the precompiled files.
"""

def hf_create_compile_cache(cache_url):
try:
return _create_hub_compile_cache_proxy(cache_url)
return _create_hub_compile_cache_proxy(cache_url, cache_repo_id=cache_repo_id)
except Exception as e:
logger.warning(f"Bypassing Hub cache because of the following error: {e}")
return create_compile_cache(cache_url)
@@ -274,6 +277,44 @@ def hf_create_compile_cache(cache_url):
patch_everywhere("create_compile_cache", create_compile_cache, "libneuronxla")


@requires_torch_neuronx
@requires_torch_xla
@contextmanager
def patch_neuron_cc_wrapper():
"""
Patches the `neuron_cc_wrapper` file so that our own version is used, which ensures that compilation goes through
our caching system.
"""

import torch_xla.core.xla_model as xm

def patch(restore: bool = False):
path = os.environ["PATH"]
main_dir = Path(path.split(":")[0])
exists = "neuron_cc_wrapper" in os.listdir(main_dir)
if exists and not restore:
src = main_dir / "neuron_cc_wrapper"
dst = main_dir / "neuron_cc_wrapper_backup"
shutil.move(src, dst)
if restore:
src = main_dir / "neuron_cc_wrapper_backup"
dst = main_dir / "neuron_cc_wrapper"
else:
src = Path(__file__).parent / "neuron_cc_wrapper"
dst = main_dir / "neuron_cc_wrapper"
shutil.copy(src, dst)

try:
if xm.get_ordinal() == 0:
patch()
xm.rendezvous("Patch neuron_cc_wrapper")
yield
finally:
if xm.get_ordinal() == 0:
patch(restore=True)
xm.rendezvous("Restore neuron_cc_wrapper")


@requires_torch_neuronx
def synchronize_hub_cache(cache_repo_id: Optional[str] = None):
"""Synchronize the neuronx compiler cache with the optimum-neuron hub cache.
9 changes: 9 additions & 0 deletions optimum/neuron/utils/neuron_cc_wrapper
@@ -0,0 +1,9 @@
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import re
import sys
from optimum.neuron.utils.optimum_neuron_cc_wrapper import main

if __name__ == '__main__':
sys.argv[0] = re.sub(r'(-script\.pyw|\.exe)?$', '', sys.argv[0])
sys.exit(main())
28 changes: 28 additions & 0 deletions optimum/neuron/utils/optimum_neuron_cc_wrapper.py
@@ -0,0 +1,28 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from libneuronxla.neuron_cc_wrapper import main as neuron_cc_wrapper_main

from .cache_utils import get_hf_hub_cache_repos
from .hub_neuronx_cache import hub_neuronx_cache


def main():
with hub_neuronx_cache(cache_repo_id=get_hf_hub_cache_repos()[0]):
return neuron_cc_wrapper_main()


if __name__ == "__main__":
main()
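This wrapper only takes effect because `patch_neuron_cc_wrapper()` copies the `neuron_cc_wrapper` shim from `optimum/neuron/utils/` into the first `PATH` entry, so a process that shells out to `neuron_cc_wrapper` resolves the optimum-neuron shim first and therefore compiles through `hub_neuronx_cache`. A quick, illustrative way to check which copy would be resolved (standard library only; assumes a Neuron environment with torch_xla and the stock wrapper installed):

import shutil

from optimum.neuron.utils.hub_neuronx_cache import patch_neuron_cc_wrapper

print(shutil.which("neuron_cc_wrapper"))  # stock libneuronxla entry point, if any
with patch_neuron_cc_wrapper():
    # While the patch is active, the first PATH entry holds the optimum-neuron shim.
    print(shutil.which("neuron_cc_wrapper"))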
1 change: 0 additions & 1 deletion text-generation-inference/integration-tests/test_gpt2.py
@@ -43,7 +43,6 @@ async def tgi_client(tgi_service):

@pytest.mark.asyncio
async def test_model_single_request(tgi_client):

# Greedy bounded without input
response = await tgi_client.generate(
"What is Deep Learning?",