Do not upload NeuronModelForCausalLM weights when they can be reconstructed from the hub #413

Merged 6 commits on Jan 18, 2024
10 changes: 5 additions & 5 deletions optimum/neuron/modeling.py
@@ -648,15 +648,15 @@ class NeuronModelForCausalLM(NeuronDecoderModel, GenerationMixin):

     def __init__(
         self,
-        model: torch.nn.Module,
         config: "PretrainedConfig",
-        model_path: Union[str, "Path", "TemporaryDirectory"],
+        checkpoint_dir: Union[str, "Path", "TemporaryDirectory"],
+        compiled_dir: Optional[Union[str, "Path", "TemporaryDirectory"]] = None,
         generation_config: Optional["GenerationConfig"] = None,
     ):
-        super().__init__(model, config, model_path, generation_config)
+        super().__init__(config, checkpoint_dir, compiled_dir=compiled_dir, generation_config=generation_config)
         self.cur_len = 0
-        self.batch_size = model.config.batch_size
-        self.max_length = model.config.n_positions
+        self.batch_size = self.model.config.batch_size
+        self.max_length = self.model.config.n_positions
         # The generate method from GenerationMixin expects the device attribute to be set
         self.device = torch.device("cpu")

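With this change, `NeuronModelForCausalLM` no longer receives a pre-instantiated `transformers_neuronx` model: the parent constructor now rebuilds it from the config and the checkpoint directory. A minimal sketch of how this flow is reached from user code (model name and kwargs are illustrative, following the `_from_transformers` signature below):

```python
from optimum.neuron import NeuronModelForCausalLM

# Exporting compiles the checkpoint for Neuron and records the source
# hub repo (checkpoint_id / checkpoint_revision) in config.neuron, so
# the weights can be reconstructed later instead of being re-uploaded.
model = NeuronModelForCausalLM.from_pretrained(
    "gpt2",
    export=True,
    batch_size=1,
    num_cores=2,
    auto_cast_type="fp32",
)
```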
255 changes: 159 additions & 96 deletions optimum/neuron/modeling_decoder.py
@@ -21,20 +21,19 @@
 from tempfile import TemporaryDirectory
 from typing import TYPE_CHECKING, Optional, Tuple, Union
 
-import torch
-from huggingface_hub import HfApi, HfFolder, snapshot_download
+from huggingface_hub import HfApi, get_token, snapshot_download
 from huggingface_hub.utils import is_google_colab
 from transformers import AutoConfig, AutoModel, GenerationConfig
 
 from ..exporters.neuron.model_configs import *  # noqa: F403
 from ..exporters.tasks import TasksManager
 from ..modeling_base import OptimizedModel
 from .utils import hub_neuronx_cache, is_transformers_neuronx_available
+from .utils.require_utils import requires_transformers_neuronx
 from .utils.version_utils import check_compiler_compatibility, get_neuronxcc_version
 
 
 if is_transformers_neuronx_available():
-    from transformers_neuronx.module import PretrainedModel as NeuronxPretrainedModel
     from transformers_neuronx.module import save_split
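
One detail worth noting in this hunk: `HfFolder.get_token` is deprecated in recent `huggingface_hub` releases in favor of the module-level `get_token` helper, which also honors the `HF_TOKEN` environment variable. A quick sketch of the replacement call:

```python
from huggingface_hub import get_token

# Returns the locally saved user token (or HF_TOKEN if set), else None.
token = get_token()
```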


@@ -69,19 +68,25 @@ class NeuronDecoderModel(OptimizedModel):
     CHECKPOINT_DIR = "checkpoint"
     COMPILED_DIR = "compiled"
 
+    @requires_transformers_neuronx
     def __init__(
         self,
-        model: torch.nn.Module,
         config: "PretrainedConfig",
-        model_path: Union[str, Path, TemporaryDirectory],
+        checkpoint_dir: Union[str, Path, TemporaryDirectory],
+        compiled_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
         generation_config: Optional[GenerationConfig] = None,
     ):
-        if not is_transformers_neuronx_available() or not isinstance(model, NeuronxPretrainedModel):
-            raise ValueError("The source model must be a transformers_neuronx.PreTrainedModel.")
+        neuron_config = getattr(config, "neuron", None)
+        if neuron_config is None:
+            raise ValueError(
+                "The specified model is not a neuron model."
+                "Please convert your model to neuron format by passing export=True."
+            )
 
-        super().__init__(model, config)
-        self.model_path = model_path
+        self.checkpoint_dir = checkpoint_dir
+        self.compiled_dir = compiled_dir
         if generation_config is None:
+            logger.info("Generation config file not found, using a generation config created from the model config.")
             generation_config = GenerationConfig.from_model_config(config)
         self.generation_config = generation_config
         # Registers the NeuronModelForXXX classes into the transformers AutoModel classes to avoid warnings when creating
@@ -90,11 +95,45 @@ def __init__(
         if hasattr(self.auto_model_class, "register"):
             self.auto_model_class.register(AutoConfig, self.__class__)
 
+        # Evaluate the configuration passed during export
+        task = neuron_config["task"]
+        batch_size = neuron_config["batch_size"]
+        sequence_length = neuron_config["sequence_length"]
+        num_cores = neuron_config["num_cores"]
+        auto_cast_type = neuron_config["auto_cast_type"]
+
+        check_compiler_compatibility(neuron_config["compiler_type"], neuron_config["compiler_version"])
+
+        exporter = get_exporter(config, task)
+
+        # transformers-neuronx uses f32/f16 instead of fp32/fp16
+        auto_cast_type = auto_cast_type.replace("p", "")
+        checkpoint_path = checkpoint_dir.name if isinstance(checkpoint_dir, TemporaryDirectory) else checkpoint_dir
+        neuronx_model = exporter.neuronx_class.from_pretrained(
+            checkpoint_path,
+            batch_size=batch_size,
+            n_positions=sequence_length,
+            tp_degree=num_cores,
+            amp=auto_cast_type,
+        )
+
+        if compiled_dir is not None:
+            # Specify the path where compiled artifacts are stored before conversion
+            neuronx_model.load(compiled_dir)
+
+        # Compile the Neuron model (if present compiled artifacts will be reloaded instead of compiled)
+        neuron_cc_flags = os.environ.get("NEURON_CC_FLAGS", "")
+        os.environ["NEURON_CC_FLAGS"] = neuron_cc_flags + " --model-type=transformer"
+        with hub_neuronx_cache():
+            neuronx_model.to_neuron()
+        os.environ["NEURON_CC_FLAGS"] = neuron_cc_flags
+
+        super().__init__(neuronx_model, config)
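
The compile step above appends `--model-type=transformer` to `NEURON_CC_FLAGS` and restores the previous value afterwards. The same pattern could be wrapped in a context manager so the environment is restored even if compilation raises; a hypothetical refactoring sketch, not part of the PR:

```python
import os
from contextlib import contextmanager

@contextmanager
def extra_neuron_cc_flags(flags: str):
    """Temporarily append extra flags to NEURON_CC_FLAGS."""
    previous = os.environ.get("NEURON_CC_FLAGS", "")
    os.environ["NEURON_CC_FLAGS"] = f"{previous} {flags}".strip()
    try:
        yield
    finally:
        os.environ["NEURON_CC_FLAGS"] = previous

# Usage, mirroring the diff:
#   with extra_neuron_cc_flags("--model-type=transformer"), hub_neuronx_cache():
#       neuronx_model.to_neuron()
```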

     @classmethod
-    def _from_transformers(
+    def _create_checkpoint(
         cls,
         model_id: str,
         config: "PretrainedConfig",
         use_auth_token: Optional[Union[bool, str]] = None,
         revision: Optional[str] = None,
         force_download: bool = False,
@@ -103,18 +142,8 @@ def _from_transformers(
         local_files_only: bool = False,
         trust_remote_code: bool = False,
         task: Optional[str] = None,
-        batch_size: Optional[int] = 1,
-        sequence_length: Optional[int] = None,
-        num_cores: Optional[int] = 2,
-        auto_cast_type: Optional[str] = "fp32",
         **kwargs,
-    ) -> "NeuronDecoderModel":
-        if not is_transformers_neuronx_available():
-            raise ModuleNotFoundError("The transformers_neuronx package is required to export the model.")
-
-        if task is None:
-            task = TasksManager.infer_task_from_model(cls.auto_model_class)
-
+    ) -> TemporaryDirectory:
         # Instantiate the transformers model checkpoint
         model = TasksManager.get_model_from_task(
             task=task,
@@ -135,6 +164,43 @@ def _from_transformers(
         model.save_pretrained(
             checkpoint_dir.name, save_function=save_split, safe_serialization=False, max_shard_size="10000GB"
         )
+        return checkpoint_dir
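
`_create_checkpoint` re-serializes the hub weights with `save_split`, the layout `transformers_neuronx` expects when loading a checkpoint for tensor-parallel inference. A standalone sketch of the equivalent steps (model name illustrative, assumes `transformers_neuronx` is installed):

```python
from tempfile import TemporaryDirectory

from transformers import AutoModelForCausalLM
from transformers_neuronx.module import save_split

# Fetch the original weights and re-save them as a split checkpoint.
model = AutoModelForCausalLM.from_pretrained("gpt2")
checkpoint_dir = TemporaryDirectory()
model.save_pretrained(
    checkpoint_dir.name, save_function=save_split, safe_serialization=False, max_shard_size="10000GB"
)
```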

+    @classmethod
+    @requires_transformers_neuronx
+    def _from_transformers(
+        cls,
+        model_id: str,
+        config: "PretrainedConfig",
+        use_auth_token: Optional[str] = None,
+        revision: Optional[str] = None,
+        task: Optional[str] = None,
+        batch_size: Optional[int] = 1,
+        sequence_length: Optional[int] = None,
+        num_cores: Optional[int] = 2,
+        auto_cast_type: Optional[str] = "fp32",
+        **kwargs,
+    ) -> "NeuronDecoderModel":
+        if task is None:
+            task = TasksManager.infer_task_from_model(cls.auto_model_class)
+
+        # Instantiate the transformers model checkpoint
+        checkpoint_dir = cls._create_checkpoint(
+            model_id,
+            task=task,
+            revision=revision,
+            **kwargs,
+        )
+
+        if os.path.isdir(model_id):
+            checkpoint_id = None
+            checkpoint_revision = None
+        else:
+            checkpoint_id = model_id
+            # Get the exact checkpoint revision (SHA1)
+            api = HfApi(token=use_auth_token)
+            model_info = api.repo_info(model_id, revision=revision)
+            checkpoint_revision = model_info.sha
 
         # If the sequence_length was not specified, deduce it from the model configuration
         if sequence_length is None:
@@ -150,111 +216,98 @@ def _from_transformers(
"sequence_length": sequence_length,
"compiler_type": "neuronx-cc",
"compiler_version": get_neuronxcc_version(),
"checkpoint_id": checkpoint_id,
"checkpoint_revision": checkpoint_revision,
}

return cls._from_pretrained(checkpoint_dir, config)
# Try to reload the generation config (if any)
generation_config = None
try:
generation_config = GenerationConfig.from_pretrained(model_id)
except OSError:
pass

return cls(config, checkpoint_dir, generation_config=generation_config)
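
After export, the `neuron` section of the config pins both the compiler environment and the source checkpoint; the latter is what later allows the weights to be skipped at upload time and re-fetched at load time. An illustrative example of the resulting entry (all values are examples):

```python
# config.neuron as serialized into config.json (illustrative values)
{
    "task": "text-generation",
    "batch_size": 1,
    "num_cores": 2,
    "auto_cast_type": "fp32",
    "sequence_length": 2048,
    "compiler_type": "neuronx-cc",
    "compiler_version": "2.12.68.0",
    "checkpoint_id": "gpt2",
    "checkpoint_revision": "e7da7f2...",  # exact commit SHA of the source repo
}
```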

     @classmethod
-    def _get_neuron_paths(
-        cls, model_dir: Union[str, Path, TemporaryDirectory], token: Optional[str] = None
-    ) -> Tuple[str, str, str]:
-        if isinstance(model_dir, TemporaryDirectory):
-            model_path = model_dir.name
-            # We are in the middle of an export: the checkpoint is in the temporary model directory
-            checkpoint_path = model_path
-            # There are no compiled artifacts yet
-            compiled_path = None
-        else:
-            # The model has already been exported
-            if os.path.isdir(model_dir):
-                model_path = model_dir
-            else:
-                # Download the neuron model from the Hub
-                model_path = snapshot_download(model_dir, token=token)
-            # The checkpoint is in a subdirectory
-            checkpoint_path = os.path.join(model_path, cls.CHECKPOINT_DIR)
-            # So are the compiled artifacts
-            compiled_path = os.path.join(model_path, cls.COMPILED_DIR)
-        return model_path, checkpoint_path, compiled_path
+    def _get_neuron_dirs(cls, model_path: Union[str, Path]) -> Tuple[str, str]:
+        # The checkpoint is in a subdirectory
+        checkpoint_dir = os.path.join(model_path, cls.CHECKPOINT_DIR)
+        # So are the compiled artifacts
+        compiled_dir = os.path.join(model_path, cls.COMPILED_DIR)
+        return checkpoint_dir, compiled_dir
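
For reference, the on-disk layout these helpers assume looks like this (a sketch; file names other than the two subdirectories are illustrative):

```
my-neuron-model/
├── config.json              # includes the "neuron" section shown above
├── generation_config.json
├── checkpoint/              # split weights, skipped on upload when
│                            # checkpoint_id / checkpoint_revision are set
└── compiled/                # serialized Neuron artifacts
```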

     @classmethod
+    @requires_transformers_neuronx
     def _from_pretrained(
         cls,
-        model_id: Union[str, Path, TemporaryDirectory],
+        model_id: Union[str, Path],
         config: "PretrainedConfig",
         use_auth_token: Optional[str] = None,
         revision: Optional[str] = None,
         **kwargs,
     ) -> "NeuronDecoderModel":
         # Verify we are actually trying to load a neuron model
         neuron_config = getattr(config, "neuron", None)
         if neuron_config is None:
             raise ValueError(
-                "The specified directory does not contain a neuron model. "
+                "The specified directory does not contain a neuron model."
                 "Please convert your model to neuron format by passing export=True."
             )
 
-        # Evaluate the configuration passed during export
-        task = neuron_config["task"]
-        batch_size = neuron_config["batch_size"]
-        sequence_length = neuron_config["sequence_length"]
-        num_cores = neuron_config["num_cores"]
-        auto_cast_type = neuron_config["auto_cast_type"]
-
-        check_compiler_compatibility(neuron_config["compiler_type"], neuron_config["compiler_version"])
-
-        exporter = get_exporter(config, task)
-
-        model_path, checkpoint_path, compiled_path = cls._get_neuron_paths(model_id, use_auth_token)
-
-        # transformers-neuronx uses f32/f16 instead of fp32/fp16
-        auto_cast_type = auto_cast_type.replace("p", "")
-        neuronx_model = exporter.neuronx_class.from_pretrained(
-            checkpoint_path,
-            batch_size=batch_size,
-            n_positions=sequence_length,
-            tp_degree=num_cores,
-            amp=auto_cast_type,
-        )
-
-        if compiled_path is not None:
-            # Specify the path where compiled artifacts are stored before conversion
-            neuronx_model.load(compiled_path)
-
-        # Compile the Neuron model (if present compiled artifacts will be reloaded instead of compiled)
-        neuron_cc_flags = os.environ.get("NEURON_CC_FLAGS", "")
-        os.environ["NEURON_CC_FLAGS"] = neuron_cc_flags + " --model-type=transformer"
-        with hub_neuronx_cache():
-            neuronx_model.to_neuron()
-        os.environ["NEURON_CC_FLAGS"] = neuron_cc_flags
+        model_path = model_id
+        if not os.path.isdir(model_id):
+            model_path = snapshot_download(model_id, token=use_auth_token, revision=revision)
+
+        checkpoint_dir, compiled_dir = cls._get_neuron_dirs(model_path)
+        if not os.path.isdir(checkpoint_dir):
+            # Try to recreate checkpoint from neuron config
+            task = neuron_config["task"]
+            checkpoint_id = neuron_config.get("checkpoint_id", None)
+            if checkpoint_id is None:
+                raise ValueError("Unable to fetch the neuron model weights files.")
+            checkpoint_revision = neuron_config["checkpoint_revision"]
+            checkpoint_dir = cls._create_checkpoint(
+                checkpoint_id,
+                task=task,
+                revision=checkpoint_revision,
+                use_auth_token=use_auth_token,
+                **kwargs,
+            )
+        assert os.path.isdir(compiled_dir)
 
         # Try to reload the generation config (if any)
         generation_config = None
         try:
-            generation_config = GenerationConfig.from_pretrained(model_path)
+            generation_config = GenerationConfig.from_pretrained(model_id)
         except OSError:
-            logger.info("Generation config file not found, using a generation config created from the model config.")
             pass
 
-        return cls(neuronx_model, config, model_id, generation_config)
+        return cls(config, checkpoint_dir, compiled_dir=compiled_dir, generation_config=generation_config)
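
This is the core of the PR on the loading side: when the `checkpoint/` subdirectory is absent from a downloaded repo, `_from_pretrained` rebuilds it from the checkpoint pinned in `config.neuron` instead of failing. A sketch of the user-facing effect (repo name illustrative):

```python
from optimum.neuron import NeuronModelForCausalLM

# The hub repo only needs config.json and the compiled/ artifacts: the
# weights are re-fetched from checkpoint_id at the pinned
# checkpoint_revision and re-split locally.
model = NeuronModelForCausalLM.from_pretrained("my-org/gpt2-neuronx")
```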

     def forward(self, *args, **kwargs):
         raise NotImplementedError()
 
     def _save_pretrained(self, save_directory: Union[str, Path]):
-        _, src_chkpt_path, src_compiled_path = self._get_neuron_paths(self.model_path)
-        _, dst_chkpt_path, dst_compiled_path = self._get_neuron_paths(save_directory)
-
-        shutil.copytree(src_chkpt_path, dst_chkpt_path)
-
-        if src_compiled_path is None:
-            # The compiled model has never been serialized: do it now
+        dst_checkpoint_path, dst_compiled_path = self._get_neuron_dirs(save_directory)
+
+        def copy_dir_to_path(src_dir: Union[str, Path, TemporaryDirectory], dst_path: Union[str, Path]):
+            if isinstance(src_dir, TemporaryDirectory):
+                shutil.copytree(src_dir.name, dst_path)
+            elif not os.path.samefile(src_dir, dst_path):
+                os.symlink(dst_path, src_dir)
+
+        # Copy checkpoint directory (it always exists)
+        copy_dir_to_path(self.checkpoint_dir, dst_checkpoint_path)
+        self.checkpoint_dir = dst_checkpoint_path
+        # Save or create compiled directory
+        if self.compiled_dir is None:
+            # The compilation artifacts have never been saved, do it now
             self.model.save(dst_compiled_path)
         else:
-            shutil.copytree(src_compiled_path, dst_compiled_path)
-
-        if isinstance(self.model_path, TemporaryDirectory):
-            # Let temporary directory go out-of-scope to release disk space
-            self.model_path = save_directory
-
+            copy_dir_to_path(self.compiled_dir, dst_compiled_path)
+        self.compiled_dir = dst_compiled_path
         self.generation_config.save_pretrained(save_directory)

     def push_to_hub(
@@ -269,7 +322,7 @@
         if isinstance(use_auth_token, str):
             huggingface_token = use_auth_token
         elif use_auth_token:
-            huggingface_token = HfFolder.get_token()
+            huggingface_token = get_token()
         else:
             raise ValueError("You need to provide `use_auth_token` to be able to push to the hub")
         api = HfApi(endpoint=endpoint)
@@ -285,6 +338,16 @@ def push_to_hub(
             exist_ok=True,
             private=private,
         )
+        ignore_patterns = []
+        neuron_config = getattr(self.config, "neuron")
+        checkpoint_id = neuron_config.get("checkpoint_id", None)
+        if checkpoint_id is not None:
+            # Avoid uploading checkpoints when the original model is available on the hub
+            ignore_patterns = [self.CHECKPOINT_DIR + "/*"]
         api.upload_folder(
-            repo_id=repository_id, folder_path=save_directory, token=huggingface_token, revision=revision
+            repo_id=repository_id,
+            folder_path=save_directory,
+            token=huggingface_token,
+            revision=revision,
+            ignore_patterns=ignore_patterns,
         )
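
End to end, the upload-side behavior this hunk enables (paths and repo names are illustrative, and the call follows the `push_to_hub` signature in this file):

```python
# Because checkpoint_id was recorded at export time, checkpoint/* is
# matched by ignore_patterns and the weights are never uploaded.
model.save_pretrained("./gpt2-neuron")
model.push_to_hub("./gpt2-neuron", "my-org/gpt2-neuronx", use_auth_token=True)
```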