Model transform docs #1665

Merged
7 commits merged on Sep 25, 2024
170 changes: 170 additions & 0 deletions docs/source/basics/model_transforms.rst
@@ -0,0 +1,170 @@
.. _model_transform_usage_label:

=====================
Multimodal Transforms
=====================

Multimodal model transforms apply model-specific data transforms to each modality and prepare :class:`~torchtune.data.Message`
objects to be input into the model. torchtune currently supports text + image model transforms.
These are intended to be drop-in replacements for tokenizers in multimodal datasets and support the standard
``encode``, ``decode``, and ``tokenize_messages`` methods.
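
Because they expose the same methods, you can call a model transform on text just like a plain tokenizer. The
snippet below is a minimal sketch of that tokenizer-style interface; it assumes the same tokenizer file path as
the examples further down, and the exact keyword arguments may differ slightly across models.

.. code-block:: python

    from torchtune.data import Message
    from torchtune.models.flamingo import FlamingoTransform

    transform = FlamingoTransform(
        path="/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model",
        tile_size=224,
        patch_size=14,
    )

    # Standard tokenizer methods are available directly on the transform
    token_ids = transform.encode("Hello world", add_bos=True, add_eos=True)
    print(transform.decode(token_ids))

    # tokenize_messages returns token ids and a loss mask for a list of messages
    messages = [Message(role="user", content="Hello world")]
    tokens, mask = transform.tokenize_messages(messages)

Internally, a model transform composes a standard text tokenizer with one or more image transforms: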

.. code-block:: python

# torchtune.models.flamingo.FlamingoTransform
class FlamingoTransform(ModelTokenizer, Transform):
def __init__(...):
# Text transform - standard tokenization
self.tokenizer = llama3_tokenizer(...)
# Image transforms
self.transform_image = CLIPImageTransform(...)
self.xattn_mask = VisionCrossAttentionMask(...)
Contributor: nit: fine to ignore, but should we import those for example completion?

Contributor Author: mm this is just to show a snippet of what a transform may look like, not intended to be runnable



.. code-block:: python

from torchtune.models.flamingo import FlamingoTransform
from torchtune.data import Message
from PIL import Image

sample = {
"messages": [
Message(
role="user",
content=[
{"type": "image", "content": Image.new(mode="RGB", size=(224, 224))},
Contributor: I love this entire file! very good job! any chance that the example here can have more than 1 image?

{"type": "image", "content": Image.new(mode="RGB", size=(224, 224))},
{"type": "text", "content": "What is common in these two images?"},
],
),
Message(
role="assistant",
content="A robot is in both images.",
),
],
}
transform = FlamingoTransform(
path="/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model",
tile_size=224,
patch_size=14,
)
tokenized_dict = transform(sample)
print(transform.decode(tokenized_dict["tokens"], skip_special_tokens=False))
# '<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n<|image|><|image|>What is common in these two images?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nA robot is in both images.<|eot_id|>'
print(tokenized_dict["encoder_input"]["images"][0].shape) # (num_tiles, num_channels, tile_height, tile_width)
# torch.Size([4, 3, 224, 224])


Using model transforms
----------------------
You can pass them into any multimodal dataset builder just as you would a model tokenizer.

.. code-block:: python

from torchtune.datasets.multimodal import the_cauldron_dataset
from torchtune.models.flamingo import FlamingoTransform

transform = FlamingoTransform(
path="/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model",
tile_size=224,
patch_size=14,
)
ds = the_cauldron_dataset(
model_transform=transform,
subset="ai2d",
)
tokenized_dict = ds[0]
print(transform.decode(tokenized_dict["tokens"], skip_special_tokens=False))
# <|begin_of_text|><|start_header_id|>user<|end_header_id|>
#
# <|image|>Question: What do respiration and combustion give out
# Choices:
# A. Oxygen
# B. Carbon dioxide
# C. Nitrogen
# D. Heat
# Answer with the letter.<|eot_id|><|start_header_id|>assistant<|end_header_id|>
#
# Answer: B<|eot_id|>
print(tokenized_dict["encoder_input"]["images"][0].shape) # (num_tiles, num_channels, tile_height, tile_width)
# torch.Size([4, 3, 224, 224])

Creating model transforms
-------------------------
Model transforms are expected to process both text and images in the sample dictionary.
Both should be contained in the ``"messages"`` field of the sample.

The following methods are required on the model transform:

- ``tokenize_messages``
- ``__call__``

.. code-block:: python

from typing import Any, List, Mapping, Tuple

from PIL import Image

from torchtune.data import Message
from torchtune.modules.tokenizers import ModelTokenizer
from torchtune.modules.transforms import Transform

class MyMultimodalTransform(ModelTokenizer, Transform):
def __init__(...):
self.tokenizer = my_tokenizer_builder(...)
self.transform_image = MyImageTransform(...)

def tokenize_messages(
self,
messages: List[Message],
add_eos: bool = True,
) -> Tuple[List[int], List[bool]]:
# Any other custom logic here
...

return self.tokenizer.tokenize_messages(
messages=messages,
add_eos=add_eos,
)

def __call__(
self, sample: Mapping[str, Any], inference: bool = False
) -> Mapping[str, Any]:
# Expected input parameters for vision encoder
encoder_input = {"images": [], "aspect_ratio": []}
messages = sample["messages"]

# Transform all images in sample
for message in messages:
for image in message.get_media():
out = self.transform_image({"image": image}, inference=inference)
encoder_input["images"].append(out["image"])
encoder_input["aspect_ratio"].append(out["aspect_ratio"])
sample["encoder_input"] = encoder_input

# Transform all text - returns same dictionary with additional keys "tokens" and "mask"
sample = self.tokenizer(sample, inference=inference)

return sample

transform = MyMultimodalTransform(...)
sample = {
"messages": [
Message(
role="user",
content=[
{"type": "image", "content": Image.new(mode="RGB", size=(224, 224))},
{"type": "image", "content": Image.new(mode="RGB", size=(224, 224))},
{"type": "text", "content": "What is common in these two images?"},
],
),
Message(
role="assistant",
content="A robot is in both images.",
),
],
}
tokenized_dict = transform(sample)
print(tokenized_dict)
# {'encoder_input': {'images': ..., 'aspect_ratio': ...}, 'tokens': ..., 'mask': ...}
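
Once defined, your custom transform can be passed to a multimodal dataset builder exactly like the
FlamingoTransform above. A minimal sketch, reusing the hypothetical ``MyMultimodalTransform`` from the previous
example:

.. code-block:: python

    from torchtune.datasets.multimodal import the_cauldron_dataset

    transform = MyMultimodalTransform(...)
    ds = the_cauldron_dataset(
        model_transform=transform,
        subset="ai2d",
    )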


Example model transforms
------------------------
- Flamingo
- :class:`~torchtune.models.flamingo.FlamingoTransform`
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -120,6 +120,7 @@ torchtune tutorials.
basics/tokenizers
basics/prompt_templates
basics/preference_datasets
basics/model_transforms

.. toctree::
:glob:
2 changes: 1 addition & 1 deletion torchtune/models/flamingo/_transform.py
@@ -27,7 +27,7 @@ class FlamingoTransform(ModelTokenizer, Transform):

Args:
path (str): Path to pretrained tiktoken tokenizer file.
tile_size (int): Size of the tiles to divide the image into. Default 224.
tile_size (int): Size of the tiles to divide the image into.
Contributor: is that true? did we remove the default?

Contributor Author: yep, this is a required argument

patch_size (int): Size of the patches used in the CLIP vision transformer model. This is
used to calculate the number of image embeddings per image.
max_num_tiles (int): Only used if possible_resolutions is NOT given.