-
Notifications
You must be signed in to change notification settings - Fork 448
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
172 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,170 @@ | ||
.. _model_transform_usage_label: | ||
|
||
===================== | ||
Multimodal Transforms | ||
===================== | ||
|
||
Multimodal model transforms apply model-specific data transforms to each modality and prepares :class:`~torchtune.data.Message` | ||
objects to be input into the model. torchtune currently supports text + image model transforms. | ||
These are intended to be drop-in replacements for tokenizers in multimodal datasets and support the standard | ||
``encode``, ``decode``, and ``tokenize_messages``. | ||
|
||
.. code-block:: python | ||
# torchtune.models.flamingo.FlamingoTransform | ||
class FlamingoTransform(ModelTokenizer, Transform): | ||
def __init__(...): | ||
# Text transform - standard tokenization | ||
self.tokenizer = llama3_tokenizer(...) | ||
# Image transforms | ||
self.transform_image = CLIPImageTransform(...) | ||
self.xattn_mask = VisionCrossAttentionMask(...) | ||
.. code-block:: python | ||
from torchtune.models.flamingo import FlamingoTransform | ||
from torchtune.data import Message | ||
from PIL import Image | ||
sample = { | ||
"messages": [ | ||
Message( | ||
role="user", | ||
content=[ | ||
{"type": "image", "content": Image.new(mode="RGB", size=(224, 224))}, | ||
{"type": "image", "content": Image.new(mode="RGB", size=(224, 224))}, | ||
{"type": "text", "content": "What is common in these two images?"}, | ||
], | ||
), | ||
Message( | ||
role="assistant", | ||
content="A robot is in both images.", | ||
), | ||
], | ||
} | ||
transform = FlamingoTransform( | ||
path="/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model", | ||
tile_size=224, | ||
patch_size=14, | ||
) | ||
tokenized_dict = transform(sample) | ||
print(transform.decode(tokenized_dict["tokens"], skip_special_tokens=False)) | ||
# '<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n<|image|><|image|>What is common in these two images?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nA robot is in both images.<|eot_id|>' | ||
print(tokenized_dict["encoder_input"]["images"][0].shape) # (num_tiles, num_channels, tile_height, tile_width) | ||
# torch.Size([4, 3, 224, 224]) | ||
Using model transforms | ||
---------------------- | ||
You can pass them into any multimodal dataset builder just as you would a model tokenizer. | ||
|
||
.. code-block:: python | ||
from torchtune.datasets.multimodal import the_cauldron_dataset | ||
from torchtune.models.flamingo import FlamingoTransform | ||
transform = FlamingoTransform( | ||
path="/tmp/Meta-Llama-3-8B-Instruct/original/tokenizer.model", | ||
tile_size=224, | ||
patch_size=14, | ||
) | ||
ds = the_cauldron_dataset( | ||
model_transform=transform, | ||
subset="ai2d", | ||
) | ||
tokenized_dict = ds[0] | ||
print(transform.decode(tokenized_dict["tokens"], skip_special_tokens=False)) | ||
# <|begin_of_text|><|start_header_id|>user<|end_header_id|> | ||
# | ||
# <|image|>Question: What do respiration and combustion give out | ||
# Choices: | ||
# A. Oxygen | ||
# B. Carbon dioxide | ||
# C. Nitrogen | ||
# D. Heat | ||
# Answer with the letter.<|eot_id|><|start_header_id|>assistant<|end_header_id|> | ||
# | ||
# Answer: B<|eot_id|> | ||
print(tokenized_dict["encoder_input"]["images"][0].shape) # (num_tiles, num_channels, tile_height, tile_width) | ||
# torch.Size([4, 3, 224, 224]) | ||
Creating model transforms | ||
------------------------- | ||
Model transforms are expected to process both text and images in the sample dictionary. | ||
Both should be contained in the ``"messages"`` field of the sample. | ||
|
||
The following methods are required on the model transform: | ||
|
||
- ``tokenize_messages`` | ||
- ``__call__`` | ||
|
||
.. code-block:: python | ||
from torchtune.modules.tokenizers import ModelTokenizer | ||
from torchtune.modules.transforms import Transform | ||
class MyMultimodalTransform(ModelTokenizer, Transform): | ||
def __init__(...): | ||
self.tokenizer = my_tokenizer_builder(...) | ||
self.transform_image = MyImageTransform(...) | ||
def tokenize_messages( | ||
self, | ||
messages: List[Message], | ||
add_eos: bool = True, | ||
) -> Tuple[List[int], List[bool]]: | ||
# Any other custom logic here | ||
... | ||
return self.tokenizer.tokenize_messages( | ||
messages=messages, | ||
add_eos=add_eos, | ||
) | ||
def __call__( | ||
self, sample: Mapping[str, Any], inference: bool = False | ||
) -> Mapping[str, Any]: | ||
# Expected input parameters for vision encoder | ||
encoder_input = {"images": [], "aspect_ratio": []} | ||
messages = sample["messages"] | ||
# Transform all images in sample | ||
for message in messages: | ||
for image in message.get_media(): | ||
out = self.transform_image({"image": image}, inference=inference) | ||
encoder_input["images"].append(out["image"]) | ||
encoder_input["aspect_ratio"].append(out["aspect_ratio"]) | ||
sample["encoder_input"] = encoder_input | ||
# Transform all text - returns same dictionary with additional keys "tokens" and "mask" | ||
sample = self.tokenizer(sample, inference=inference) | ||
return sample | ||
transform = MyMultimodalTransform(...) | ||
sample = { | ||
"messages": [ | ||
Message( | ||
role="user", | ||
content=[ | ||
{"type": "image", "content": Image.new(mode="RGB", size=(224, 224))}, | ||
{"type": "image", "content": Image.new(mode="RGB", size=(224, 224))}, | ||
{"type": "text", "content": "What is common in these two images?"}, | ||
], | ||
), | ||
Message( | ||
role="assistant", | ||
content="A robot is in both images.", | ||
), | ||
], | ||
} | ||
tokenized_dict = transform(sample) | ||
print(tokenized_dict) | ||
# {'encoder_input': {'images': ..., 'aspect_ratio': ...}, 'tokens': ..., 'mask': ...} | ||
Example model transforms | ||
------------------------ | ||
- Flamingo | ||
- :class:`~torchtune.models.flamingo.FlamingoTransform` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters