[train] TransformersPredictor: Add support for custom pipeline class (#36494)

Creating a `TransformersPredictor` with a custom pipeline class is currently broken: the model can't be automatically inferred from a checkpoint path, since that inference only works with the default transformers `pipeline` factory. This PR adds support by introducing additional parameters to `TransformersPredictor.from_checkpoint()` that, if specified, call `TransformersCheckpoint.get_model()` to retrieve the model.
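A minimal usage sketch, modeled on the test added in this PR (the tiny test model and the `./ckpt` path are placeholders, not part of the change):

```python
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    GPT2LMHeadModel,
)
from transformers.pipelines import Pipeline

from ray.train.huggingface import TransformersCheckpoint, TransformersPredictor


class MyPipeline(Pipeline):
    """Stub custom pipeline; a real one would implement actual pre/postprocessing."""

    def _sanitize_parameters(self, **kwargs):
        return {}, {}, {}

    def preprocess(self, inputs, **kwargs):
        pass

    def _forward(self, model_inputs, **kwargs):
        pass

    def postprocess(self, model_outputs, **kwargs):
        pass


# Save a tiny model and tokenizer as a TransformersCheckpoint.
config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-gpt2")
model = AutoModelForCausalLM.from_config(config)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
checkpoint = TransformersCheckpoint.from_model(model, tokenizer, path="./ckpt")

# `model_cls` makes from_checkpoint() call TransformersCheckpoint.get_model()
# to instantiate the model, since a custom pipeline class cannot infer it
# from the checkpoint path on its own.
predictor = TransformersPredictor.from_checkpoint(
    checkpoint, pipeline_cls=MyPipeline, model_cls=GPT2LMHeadModel
)
```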

Signed-off-by: Kai Fricke <kai@anyscale.com>
krfricke authored Jun 21, 2023
1 parent 656fe07 commit 7ed5c6d
Showing 3 changed files with 87 additions and 11 deletions.
@@ -34,6 +34,7 @@ def tf_get_gpus():

TRANSFORMERS_IMPORT_ERROR: Optional[ImportError] = None
try:
from transformers import PreTrainedModel, TFPreTrainedModel
from transformers.pipelines import Pipeline
from transformers.pipelines import pipeline as pipeline_factory
from transformers.pipelines.table_question_answering import (
@@ -102,6 +103,10 @@ def from_checkpoint(
checkpoint: Checkpoint,
*,
pipeline_cls: Optional[Type["Pipeline"]] = None,
model_cls: Optional[
Union[str, Type["PreTrainedModel"], Type["TFPreTrainedModel"]]
] = None,
pretrained_model_kwargs: Optional[dict] = None,
use_gpu: bool = False,
**pipeline_kwargs,
) -> "TransformersPredictor":
@@ -110,7 +115,7 @@ def from_checkpoint(
The checkpoint is expected to be a result of ``TransformersTrainer``.
Note that the Transformers ``pipeline`` used internally expects to
recieve raw text. If you have any Preprocessors in Checkpoint
receive raw text. If you have any Preprocessors in Checkpoint
that tokenize the data, remove them by calling
``Checkpoint.set_preprocessor(None)`` beforehand.
@@ -121,15 +126,21 @@ def from_checkpoint(
pipeline_cls: A ``transformers.pipelines.Pipeline`` class to use.
If not specified, will use the ``pipeline`` abstraction
wrapper.
model_cls: A ``transformers.PreTrainedModel`` class to create from
the checkpoint.
pretrained_model_kwargs: If set and a ``model_cls`` is provided, will
be passed to ``TransformersCheckpoint.get_model()``.
use_gpu: If set, the model will be moved to GPU on instantiation and
prediction happens on GPU.
**pipeline_kwargs: Any kwargs to pass to the pipeline
initialization. If ``pipeline`` is None, this must contain
the 'task' argument. Cannot contain 'model'. Can be used
initialization. If ``pipeline_cls`` is None, this must contain
the 'task' argument. Can be used
to override the tokenizer with 'tokenizer'. If ``use_gpu`` is
True, 'device' will be set to 0 by default, unless 'device_map' is
passed.
"""
from ray.train.huggingface import TransformersCheckpoint

if TRANSFORMERS_IMPORT_ERROR is not None:
raise TRANSFORMERS_IMPORT_ERROR

@@ -140,12 +151,39 @@ def from_checkpoint(
if use_gpu and "device_map" not in pipeline_kwargs:
# default to using the GPU with the first index
pipeline_kwargs.setdefault("device", 0)
pipeline_cls = pipeline_cls or pipeline_factory

model = None
if model_cls:
if not isinstance(checkpoint, TransformersCheckpoint):
raise ValueError(
"If `model_cls` is passed, the checkpoint has to be a "
"`TransformersCheckpoint`."
)
pretrained_model_kwargs = pretrained_model_kwargs or {}
model = checkpoint.get_model(model_cls, **pretrained_model_kwargs)

if pipeline_cls and model:
# Custom pipeline is passed and model was retrieved
pipeline = pipeline_cls(model, **pipeline_kwargs)
else:
# Custom pipeline class
if pipeline_cls:
pipeline_kwargs["pipeline_class"] = pipeline_cls

if not model:
# Infer model from checkpoint
with checkpoint.as_directory() as checkpoint_path:
# Tokenizer will be loaded automatically (no need to specify
# `tokenizer=checkpoint_path`)
pipeline = pipeline_factory(
model=checkpoint_path, **pipeline_kwargs
)
else:
# Use model with default pipeline
pipeline = pipeline_factory(model=model, **pipeline_kwargs)

preprocessor = checkpoint.get_preprocessor()
with checkpoint.as_directory() as checkpoint_path:
# Tokenizer will be loaded automatically (no need to specify
# `tokenizer=checkpoint_path`)
pipeline = pipeline_cls(model=checkpoint_path, **pipeline_kwargs)

return cls(
pipeline=pipeline,
preprocessor=preprocessor,
42 changes: 40 additions & 2 deletions python/ray/train/tests/test_transformers_predictor.py
@@ -10,8 +10,13 @@
from ray.air.util.data_batch_conversion import _convert_pandas_to_batch_type
from ray.train.batch_predictor import BatchPredictor
from ray.train.predictor import TYPE_TO_ENUM
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.pipelines import pipeline
from transformers import (
AutoConfig,
AutoModelForCausalLM,
AutoTokenizer,
GPT2LMHeadModel,
)
from transformers.pipelines import pipeline, Pipeline


import ray
@@ -31,6 +36,20 @@
tokenizer_checkpoint = "hf-internal-testing/tiny-random-gpt2"


class CustomPipeline(Pipeline):
def _forward(self, input_tensors, **forward_parameters):
pass

def _sanitize_parameters(self, **pipeline_parameters):
return {}, {}, {}

def postprocess(self, model_outputs, **postprocess_parameters):
pass

def preprocess(self, input_, **preprocess_parameters):
pass


def test_repr(tmpdir):
predictor = TransformersPredictor()

@@ -90,6 +109,25 @@ def test_predict_no_preprocessor_no_training(tmpdir, ray_start_4_cpus):
assert len(predictions) == 3


@pytest.mark.parametrize("model_cls", [GPT2LMHeadModel, None])
def test_custom_pipeline(tmpdir, model_cls):
"""Create predictor from a custom pipeline class."""
model_config = AutoConfig.from_pretrained(model_checkpoint)
model = AutoModelForCausalLM.from_config(model_config)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_checkpoint)
checkpoint = TransformersCheckpoint.from_model(model, tokenizer, path=tmpdir)

if model_cls:
kwargs = {}
else:
kwargs = {"task": "text-generation"}

predictor = TransformersPredictor.from_checkpoint(
checkpoint, pipeline_cls=CustomPipeline, model_cls=model_cls, **kwargs
)
assert isinstance(predictor.pipeline, CustomPipeline)


def create_checkpoint():
with tempfile.TemporaryDirectory() as tmpdir:
model_config = AutoConfig.from_pretrained(model_checkpoint)
2 changes: 1 addition & 1 deletion python/ray/tune/examples/tune_mnist_keras.py
@@ -85,4 +85,4 @@ def tune_mnist(num_training_iterations):
if args.smoke_test:
ray.init(num_cpus=4)

tune_mnist(num_training_iterations=5 if args.smoke_test else 300)
tune_mnist(num_training_iterations=2 if args.smoke_test else 300)
