Model transform docs #1665

@@ -0,0 +1,170 @@

.. _model_transform_usage_label:

=====================
Multimodal Transforms
=====================

Multimodal model transforms apply model-specific data transforms to each modality and prepare :class:`~torchtune.data.Message`
objects to be input into the model. torchtune currently supports text + image model transforms.
These are intended to be drop-in replacements for tokenizers in multimodal datasets and support the standard
``encode``, ``decode``, and ``tokenize_messages`` methods.

.. code-block:: python

    # torchtune.models.flamingo.FlamingoTransform
    class FlamingoTransform(ModelTokenizer, Transform):
        def __init__(...):
            # Text transform - standard tokenization
            self.tokenizer = llama3_tokenizer(...)
            # Image transforms
            self.transform_image = CLIPImageTransform(...)
            self.xattn_mask = VisionCrossAttentionMask(...)

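Because a model transform is a drop-in replacement for a tokenizer, the standard text methods can be called on it directly. A minimal sketch; the ``add_bos``/``add_eos`` arguments mirror torchtune's usual tokenizer ``encode`` signature and, like the tokenizer path, are assumptions for illustration:

.. code-block:: python

    from torchtune.models.flamingo import FlamingoTransform

    transform = FlamingoTransform(
        path="/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model",
        tile_size=224,
        patch_size=14,
    )
    # Text-only round trip through the standard tokenizer methods
    token_ids = transform.encode("What is common in these two images?", add_bos=True, add_eos=True)
    print(transform.decode(token_ids))
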
.. code-block:: python

    from torchtune.models.flamingo import FlamingoTransform
    from torchtune.data import Message
    from PIL import Image

    sample = {
        "messages": [
            Message(
                role="user",
                content=[
                    {"type": "image", "content": Image.new(mode="RGB", size=(224, 224))},
{"type": "image", "content": Image.new(mode="RGB", size=(224, 224))}, | ||
{"type": "text", "content": "What is common in these two images?"}, | ||
], | ||
), | ||
Message( | ||
role="assistant", | ||
content="A robot is in both images.", | ||
), | ||
], | ||
} | ||
transform = FlamingoTransform( | ||
path="/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model", | ||
tile_size=224, | ||
patch_size=14, | ||
) | ||
tokenized_dict = transform(sample) | ||
print(transform.decode(tokenized_dict["tokens"], skip_special_tokens=False)) | ||
# '<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n<|image|><|image|>What is common in these two images?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nA robot is in both images.<|eot_id|>' | ||
print(tokenized_dict["encoder_input"]["images"][0].shape) # (num_tiles, num_channels, tile_height, tile_width) | ||
# torch.Size([4, 3, 224, 224]) | ||
|
||
|
||
Using model transforms
----------------------
You can pass a model transform into any multimodal dataset builder just as you would a model tokenizer.

.. code-block:: python

    from torchtune.datasets.multimodal import the_cauldron_dataset
    from torchtune.models.flamingo import FlamingoTransform

    transform = FlamingoTransform(
        path="/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model",
        tile_size=224,
        patch_size=14,
    )
    ds = the_cauldron_dataset(
        model_transform=transform,
        subset="ai2d",
    )
    tokenized_dict = ds[0]
    print(transform.decode(tokenized_dict["tokens"], skip_special_tokens=False))
    # <|begin_of_text|><|start_header_id|>user<|end_header_id|>
    #
    # <|image|>Question: What do respiration and combustion give out
    # Choices:
    # A. Oxygen
    # B. Carbon dioxide
    # C. Nitrogen
    # D. Heat
    # Answer with the letter.<|eot_id|><|start_header_id|>assistant<|end_header_id|>
    #
    # Answer: B<|eot_id|>
    print(tokenized_dict["encoder_input"]["images"][0].shape)  # (num_tiles, num_channels, tile_height, tile_width)
    # torch.Size([4, 3, 224, 224])

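The returned dataset is a standard map-style PyTorch dataset, so it can be iterated with a ``DataLoader``. A minimal sketch; the identity collate below is purely illustrative, since real training recipes use a padded collate to batch variable-length token lists and image tile tensors:

.. code-block:: python

    from torch.utils.data import DataLoader

    # Identity collate for illustration only; training recipes pad and
    # stack variable-length token lists and image tile tensors instead.
    loader = DataLoader(ds, batch_size=1, collate_fn=lambda batch: batch[0])
    for sample in loader:
        print(len(sample["tokens"]))  # number of tokens in the first sample
        break
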
Creating model transforms
-------------------------
Model transforms are expected to process both text and images in the sample dictionary.
Both should be contained in the ``"messages"`` field of the sample.

The following methods are required on the model transform:

- ``tokenize_messages``
- ``__call__``

.. code-block:: python

    from typing import Any, List, Mapping, Tuple

    from PIL import Image

    from torchtune.data import Message
    from torchtune.modules.tokenizers import ModelTokenizer
    from torchtune.modules.transforms import Transform

    class MyMultimodalTransform(ModelTokenizer, Transform):
        def __init__(...):
            self.tokenizer = my_tokenizer_builder(...)
            self.transform_image = MyImageTransform(...)

        def tokenize_messages(
            self,
            messages: List[Message],
            add_eos: bool = True,
        ) -> Tuple[List[int], List[bool]]:
            # Any other custom logic here
            ...

            return self.tokenizer.tokenize_messages(
                messages=messages,
                add_eos=add_eos,
            )

        def __call__(
            self, sample: Mapping[str, Any], inference: bool = False
        ) -> Mapping[str, Any]:
            # Expected input parameters for vision encoder
            encoder_input = {"images": [], "aspect_ratio": []}
            messages = sample["messages"]

            # Transform all images in sample
            for message in messages:
                for image in message.get_media():
                    out = self.transform_image({"image": image}, inference=inference)
                    encoder_input["images"].append(out["image"])
                    encoder_input["aspect_ratio"].append(out["aspect_ratio"])
            sample["encoder_input"] = encoder_input

            # Transform all text - returns the same dictionary with additional keys "tokens" and "mask"
            sample = self.tokenizer(sample, inference=inference)

            return sample

    transform = MyMultimodalTransform(...)
    sample = {
        "messages": [
            Message(
                role="user",
                content=[
                    {"type": "image", "content": Image.new(mode="RGB", size=(224, 224))},
                    {"type": "image", "content": Image.new(mode="RGB", size=(224, 224))},
                    {"type": "text", "content": "What is common in these two images?"},
                ],
            ),
            Message(
                role="assistant",
                content="A robot is in both images.",
            ),
        ],
    }
    tokenized_dict = transform(sample)
    print(tokenized_dict)
    # {'encoder_input': {'images': ..., 'aspect_ratio': ...}, 'tokens': ..., 'mask': ...}

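The custom transform can then be dropped into a dataset builder exactly like the Flamingo transform earlier; ``MyMultimodalTransform`` and its elided arguments are the hypothetical ones from the sketch above:

.. code-block:: python

    from torchtune.datasets.multimodal import the_cauldron_dataset

    # MyMultimodalTransform is the hypothetical transform sketched above
    my_transform = MyMultimodalTransform(...)
    ds = the_cauldron_dataset(
        model_transform=my_transform,
        subset="ai2d",
    )
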
Example model transforms
------------------------

- Flamingo

  - :class:`~torchtune.models.flamingo.FlamingoTransform`

@@ -27,7 +27,7 @@ class FlamingoTransform(ModelTokenizer, Transform):

     Args:
         path (str): Path to pretrained tiktoken tokenizer file.
-        tile_size (int): Size of the tiles to divide the image into. Default 224.
+        tile_size (int): Size of the tiles to divide the image into.
         patch_size (int): Size of the patches used in the CLIP vision transformer model. This is
             used to calculate the number of image embeddings per image.
         max_num_tiles (int): Only used if possible_resolutions is NOT given.

Review comment: is that true? did we remove the default?

Reply: yep, this is a required argument

Review comment: nit: fine to ignore, but should we import those for example completion?

Reply: mm this is just to show a snippet of what a transform may look like, not intended to be runnable