[1/n] Merged fine-tuning dataset: grammar + samsum #1234
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1234
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 75db622 with merge base 8519c35.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Where are my deprecation warnings? :)
will be added in an upcoming PR... there's a lot stacked on this one :)
Codecov Report
Attention: Patch coverage is

Additional details and impacted files

@@           Coverage Diff            @@
##             main    #1234       +/-   ##
===========================================
+ Coverage   27.41%   69.60%   +42.19%
===========================================
  Files         233      238        +5
  Lines       10591    10771      +180
===========================================
+ Hits         2903     7497     +4594
+ Misses       7688     3274     -4414

☔ View full report in Codecov by Sentry.
First pass, not thinking too hard about the design. The main thing that caught my attention was trying to understand what exactly FinetuneDataset is. The first line makes it sound like it's a base or a general-purpose dataset that other datasets would use, but Chat and Instruct don't, so it is not immediately clear to me when to use one or the other. I hope it makes sense.
torchtune/datasets/_finetune.py
Outdated
class FinetuneDataset(Dataset):
    """
    Dataset class for creating instruct, chat, tool, or multimodal datasets for fine-tuning.
Do you plan to use it in ChatDataset? At first, when reading this, I was thinking: "Ok, this is a base class or a general purpose dataset that other datasets will generally use", but then I checked ChatDataset and it is not there. Same for InstructDataset.
So, if the answer is no, I would be a bit confused about this description/location/naming.
If the answer is yes, maybe BaseDataset/GeneralDataset could be candidates?
No. This will replace both Instruct and ChatDataset. We are essentially merging the two, while also adding support for multimodal.
I haven't thought of a better name than FinetuneDataset. Maybe TuneDataset or TokenizedDataset?
torchtune/data/_prompt_templates.py
Outdated
    pass


class CustomPromptTemplate(PromptTemplate):
I'm in favor of renaming this to PromptTemplate and calling the other thing PromptTemplateInterface.
The "Custom" part of prompt template feels redundant. Like what's the difference between a CustomPromptTemplate and a regular PromptTemplate? I see here by looking at the code that it's b/c one is an interface, but that's not clear from the names. This should be evident before having to go to the docs.
We also use the Interface naming for protocols with our recipes so it makes sense.
That makes sense to me, although if a user needs a bit more custom behavior that's not offered by CustomPromptTemplate, I wouldn't want them to accidentally inherit from PromptTemplate instead of PromptTemplateInterface.
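To make the distinction concrete, here's a rough sketch of how the proposed naming could shake out (toy Message stand-in; signatures assumed for illustration, not the actual PR code):

```python
from typing import Dict, List, Protocol, Tuple


class Message:
    # Toy stand-in for torchtune.data.Message, only for this sketch.
    def __init__(self, role: str, content: str):
        self.role = role
        self.content = content


class PromptTemplateInterface(Protocol):
    # The interface only promises "take messages, return formatted messages".
    def __call__(self, messages: List[Message]) -> List[Message]:
        ...


class PromptTemplate:
    # Concrete, configurable template: wraps each role's content in a (prefix, suffix) pair.
    def __init__(self, template: Dict[str, Tuple[str, str]]):
        self.template = template

    def __call__(self, messages: List[Message]) -> List[Message]:
        formatted = []
        for message in messages:
            prefix, suffix = self.template.get(message.role, ("", ""))
            formatted.append(Message(message.role, f"{prefix}{message.content}{suffix}"))
        return formatted


# A user with more custom needs would implement PromptTemplateInterface directly
# instead of configuring (or inheriting from) PromptTemplate.
grammar_template = PromptTemplate({"user": ("Correct this to standard English: ", "\n---\n")})
```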
torchtune/datasets/_finetune.py
Outdated
from torchtune.modules.transforms import Transform


class FinetuneDataset(Dataset):
Is this actually only for SFT? If so, what makes it SFT specific?
And if it is, then should we rename to SFTDataset????
It is SFT specific in the sense that it isn't Preference or TextCompletion. i.e., it covers the previous instruct, chat, and eventual multimodal datasets, but it does not cover 1) chosen/rejected messages in Preference, 2) using tokenizer.encode directly instead of tokenizer.tokenize_messages in TextCompletion.
I'm cool with SFTDataset tho
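To illustrate the split (field names invented for illustration):

```python
# Covered by this class: samples that become a list of messages and go through
# the tokenizer's tokenize_messages path.
sft_sample = {"input": "me want grammar good", "output": "I want good grammar."}

# Not covered: preference data with separate chosen/rejected message lists...
preference_sample = {"prompt": "...", "chosen": "...", "rejected": "..."}

# ...and text completion data, which tokenizes raw text directly via encode.
text_completion_sample = {"text": "The quick brown fox jumps over the lazy dog."}
```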
Let's get some other opinions.
milord? you asked for me?
SupervisedDataset? TaskDataset? SupervisedTaskDataset? STDataset?
otherwise +1 for something slightly more descriptive than FinetuneDataset, SFTDataset is OK too
torchtune/datasets/_finetune.py
Outdated
    All datasets are formatted into :class:`~torchtune.data.Message`s because for
    fine-tuning, datasets can be considered as "conversations" with the model,
    or AI assistant. Thus, we can standardize all text content as messages in a conversation assigned to
    a :class:`~torchtune.data.Role`:
This isn't rendering correctly.
oh no
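For what it's worth, the framing itself is simple: a single sample ends up as role-tagged messages, roughly like this (plain dicts for illustration rather than actual Message objects, content invented):

```python
conversation = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Correct this to standard English: me want grammar good"},
    {"role": "assistant", "content": "I want good grammar."},
]
```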
torchtune/datasets/_finetune.py
Outdated
    multimodal datasets requires processing the images in a way specific to the vision
    encoder being used by the model and is agnostic to the specific dataset.

    Tokenization is handled by the ``model_transform``. All :class:`~torchtune.modules.tokenizers.ModelTokenizer`s
:class:`~torchtune.modules.tokenizers.ModelTokenizer` is not rendering correctly :(
because we don't actually generate the doc for that anywhere perhaps?
    Tokenization is handled by the ``model_transform``. All :class:`~torchtune.modules.tokenizers.ModelTokenizer`s
    can be treated as a ``model_transform`` since it uses the model-specific tokenizer to
    transform the list of messages outputted from the ``message_transform`` into tokens
    used by the model for training. Text-only datasets will simply pass the :class:`~torchtune.modules.tokenizers.ModelTokenizer`
Not rendered.
torchtune/datasets/_finetune.py
Outdated
    - Task-specific templates to gear models for a particular task that it will expect after training
    - Model-specific templates that are required whenever the model is prompted, such as the [INST]
      tags in Llama2 and in Mistral
    - Community standardized templates, such as :class:`~torchtune.data.ChatMLTemplate`
Not rendered
torchtune/datasets/_finetune.py
Outdated
    Args:
        source (str): path to dataset repository on Hugging Face. For local datasets,
            define source as the data file type (e.g. "json", "csv", "text") and pass
            in the filepath in ``data_files``. See Hugging Face's ``load_dataset``
Can we make this a hyperlink?
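For reference, the local-file usage described there corresponds to something like this on the Hugging Face side (file path made up), since ``source`` and ``data_files`` are forwarded to ``load_dataset``:

```python
from datasets import load_dataset

# source="json" plus a data_files path loads a local JSON file instead of a hub dataset.
ds = load_dataset("json", data_files="data/my_sft_data.json", split="train")
```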
torchtune/datasets/_finetune.py
Outdated
        filter_fn (Optional[Callable]): callable used to filter the dataset prior to any pre-processing. See
            the Hugging Face `docs <https://huggingface.co/docs/datasets/v2.20.0/process#select-and-filter>`_ for more
            details.
        **load_dataset_kwargs (Dict[str, Any]): additional keyword arguments to pass to ``load_dataset``.
Hyperlink
        self,
        *,
        source: str,
        message_transform: Transform,
For some reason, Transform is not being picked up. Is it included in the docs somewhere?
I'm guessing no...
I am the King of Nits.
    [10, 1, 6, -1]
]
ds = SFTDataset(
    source="iam/agoofy/goober",
nice
This looks world-class - docs are really clear and easy to understand. I'm excited for this to land 💯
Please don't hate me for only getting around to reviewing this after you landed it. Anyways it looks great
@@ -130,113 +130,6 @@ def format(
        return prompt


class GrammarErrorCorrectionTemplate(InstructTemplate):
😍
@@ -8,6 +8,7 @@
from torchtune.datasets._chat import chat_dataset, ChatDataset
from torchtune.datasets._cnn_dailymail import cnn_dailymail_articles_dataset
from torchtune.datasets._concat import ConcatDataset
from torchtune.datasets._finetune import SFTDataset
Sorry to be that guy but imo it's a bit confusing that our canonical dataset does not match the pattern of all our other datasets (class name is the capitalized version of the filename). I'm good with the name SFTDataset, but then maybe we should rename the file too?
yeah I missed that. will include that change in the subsequent PRs
        mask = truncate(mask, max_seq_len, True)
        if self.max_seq_len:
            tokens = truncate(tokens, self.max_seq_len, self.eos_id)
            mask = truncate(mask, self.max_seq_len, True)
This one doesn't need to add __call__ / Transform?
oh good call out
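A rough sketch of what adding it could look like, if we go that route (toy tokenizer, names assumed for illustration):

```python
from typing import Any, List, Mapping, Tuple


class ToyTokenizer:
    # Stand-in for a model tokenizer; only the Transform plumbing is illustrated here.
    def tokenize_messages(self, messages: List[dict]) -> Tuple[List[int], List[bool]]:
        tokens = [len(m["content"]) for m in messages]
        mask = [m["role"] != "assistant" for m in messages]
        return tokens, mask

    def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
        # __call__ is what lets the tokenizer be dropped in wherever a
        # Transform / model_transform is expected.
        tokens, mask = self.tokenize_messages(sample["messages"])
        return {**sample, "tokens": tokens, "mask": mask}
```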
        tokenizer (ModelTokenizer): Tokenizer used by the model that implements the ``tokenize_messages`` method.
        source (str): path string of dataset, anything supported by Hugging Face's ``load_dataset``.
        model_transform (Transform): model specific transform to convert a list of messages
            output by the dataset to tokens. This will always be a :class:`~torchtune.modules.tokenizers.ModelTokenizer`.
"This will always be a :class:`~torchtune.modules.tokenizers.ModelTokenizer`."
Just curious, why do we not type it as such then?
mainly to be consistent with SFTDataset and the upcoming multimodal dataset builders. I'd be open to still calling it model_transform but typing it ModelTokenizer, but I don't know if we'd want to go all the way and just call this tokenizer
            Default is False.
        column_map (Optional[Dict[str, str]]): a mapping to change the expected "input"
            and "output" column names to the actual column names in the dataset. Default is None,
            keeping the default "input" and "output" column names.
Should give an example in the docstring
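Agreed, something along these lines maybe (column names invented):

```python
# A dataset whose columns are named "question"/"response" instead of the expected
# "input"/"output" would be remapped like this:
column_map = {"input": "question", "output": "response"}
```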
Context

What is the purpose of this PR?

As discussed in the RFC in #1186, we will merge instruct and chat datasets into the following unified pipeline that can better support multimodal:

- `message_transform` to create a `List[Message]` from the dataset, with full flexibility on columns, ad-hoc modifications, etc. For multimodal, images are additionally loaded from their paths
- `prompt_template` as an optional way to add structured text around specific roles in the list of messages
- `model_transform` that takes the list of messages and tokenizes it. For multimodal, it will additionally apply model-specific image transforms to the images associated with the sample (a rough sketch of how these stages compose follows this list)
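A minimal sketch of how the three stages compose on a single sample (toy stand-ins only, not the actual torchtune classes; names and behavior here are illustrative):

```python
from typing import Any, Dict, List, Mapping


class Message:
    # Toy stand-in for torchtune.data.Message.
    def __init__(self, role: str, content: str):
        self.role = role
        self.content = content


class ToyInputOutputMessages:
    # message_transform: input column -> user message, output column -> assistant message.
    def __call__(self, sample: Mapping[str, Any]) -> List[Message]:
        return [Message("user", sample["input"]), Message("assistant", sample["output"])]


class ToyGrammarTemplate:
    # prompt_template: optional structured text around specific roles.
    def __call__(self, messages: List[Message]) -> List[Message]:
        messages[0].content = f"Correct this to standard English: {messages[0].content}"
        return messages


class ToyModelTransform:
    # model_transform: tokenizes the formatted messages (a real one would be the model tokenizer).
    def __call__(self, messages: List[Message]) -> Dict[str, Any]:
        tokens = [tok for m in messages for tok in m.content.split()]
        return {"tokens": tokens, "labels": tokens}


def unified_sample_pipeline(sample: Mapping[str, Any]) -> Dict[str, Any]:
    # The stages run in order: raw sample -> messages -> templated messages -> tokens.
    messages = ToyInputOutputMessages()(sample)
    messages = ToyGrammarTemplate()(messages)
    return ToyModelTransform()(messages)


print(unified_sample_pipeline({"input": "me want grammar good", "output": "I want good grammar."}))
```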
For ease of review, we will stage this as multiple moderate-sized PRs. This PR creates the unified dataset class, and refactors grammar and samsum to start off with. As a result, a few key changes were made:

- Created the `FinetuneDataset` (not married to the name, best I can think of) class with the unified pipeline, with an associated unit test
- Created `ToInputOutputMessages` (open to better names), which provides a generic message transform that takes input column -> user message, output column -> assistant message
- Refactored grammar and samsum to use `ToInputOutputMessages`, add a default prompt template that can be changed, and use the new `FinetuneDataset` class
- Moved `Message` from `_types.py` to `_messages.py`. This file will now contain everything related to `Message`, including generic message transforms (like the ones in `_converters.py`, which will eventually migrate here)
- Made tokenizers `Transform`s for use in the `model_transform` argument
- Created a `PromptTemplate` interface that merges the functionality of `InstructTemplate` and `ChatFormat`. Refactored grammar and summarize templates to use a common `CustomPromptTemplate` class that takes in any template and formats a list of messages accordingly. It also covers the ability to specify a custom template from configs, which we were missing. More will be discussed in an upcoming tutorial update.

Test plan
- Unit tests for `FinetuneDataset`, grammar, samsum, `ToInputOutputMessages`
- `pytest tests`
- `pytest tests -m integration_test`