Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: [#1002] update chat tutorial so that it works as is #1004

Merged

Conversation

christobill
Copy link
Contributor

@christobill christobill commented May 19, 2024

Context

This PR:

  • update documentation

This PR addresses:
#1002

Changelog

Changes made in this PR:
Update chat tutorial so that it works as is, here is the link to this documentation
https://pytorch.org/torchtune/main/tutorials/chat.html

Copy link

pytorch-bot bot commented May 19, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1004

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 964315a with merge base d896253 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot
Copy link

Hi @christobill!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!

@facebook-github-bot
Copy link

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label May 19, 2024
@RdoubleA
Copy link
Contributor

Appreciate this update! Did we figure out whether the issue was using version 0.1.1 or setting the split parameter? In your case and in the tutorial of using a local dataset, I don't believe you need to set split='train', that is only for datasets on Hugging Face Hub

Copy link
Contributor

@joecummings joecummings left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This makes sense to me! Leaving findal merge approval to @RdoubleA though.

@christobill
Copy link
Contributor Author

christobill commented May 20, 2024

Thanks @joecummings !

@RdoubleA
TD;DR: the issue was coming from split='train and from the torchtune version

To be sure I restarted the whole process on a new machine:

Setup torchtune nightly:

pip install --pre torchtune --extra-index-url https://download.pytorch.org/whl/nightly/cpu --no-cache-dir

tune download meta-llama/Meta-Llama-3-8B-Instruct \
  --output-dir /tmp/Meta-Llama-3-8B-Instruct \
  --hf-token ************************
  
export PYTHONPATH=$PYTHONPATH:/workspace

tune cp llama3/8B_qlora_single_device custom_config_qlora.yaml

my_module.py with code from https://pytorch.org/torchtune/main/tutorials/chat.html:

from torchtune.modules.tokenizers import Tokenizer
from torchtune.datasets import ChatDataset
from torchtune.data import Message
from typing import Mapping, Any, List

def message_converter(sample: Mapping[str, Any]) -> List[Message]:
    input_msg = sample["input"]
    output_msg = sample["output"]

    user_message = Message(
        role="user",
        content=input_msg,
        masked=True,  # Mask if not training on prompt
    )
    assistant_message = Message(
        role="assistant",
        content=output_msg,
        masked=False,
    )
    # A single turn conversation
    messages = [user_message, assistant_message]

    return messages

def custom_dataset(
    *,
    tokenizer: Tokenizer,
    max_seq_len: int = 2048,  # You can expose this if you want to experiment
) -> ChatDataset:

    return ChatDataset(
        tokenizer=tokenizer,
        # For local csv files, we specify "csv" as the source, just like in
        # load_dataset
        source="csv",
        convert_to_messages=message_converter,
        # Llama3 does not need a chat format
        chat_format=None,
        max_seq_len=max_seq_len,
        # To load a local file we specify it as data_files just like in
        # load_dataset
        data_files="your_file.csv",
    )

Changing custom_config_qlora.yaml dataset:

dataset:
  _component_: my_module.custom_dataset
  max_seq_len: 2048

your_file.csv:

"input","output"
"How do GPS receivers communicate with satellites?","The first thing to know is the communication is one-way..."
"What are the main components of a computer motherboard?","The motherboard is the main circuit board of a computer..."
"How does a microwave oven work?","Microwave ovens cook food using electromagnetic radiation..."
"What are the symptoms of a computer virus infection?","Symptoms of a computer virus infection can vary depending on the type of virus..."
"What is machine learning?","Machine learning is a subset of artificial intelligence..."
"How do airplanes fly?","Airplanes fly using the principles of aerodynamics..."
"What causes earthquakes?","Earthquakes are caused by the sudden release of energy..."
"How does the internet work?","The internet is a global network of interconnected computers..."
"What are the stages of the water cycle?","The water cycle consists of several stages..."
"What is the greenhouse effect?","The greenhouse effect is a natural process that warms the Earth's surface..."
"How does the human brain process information?","The human brain processes information through a complex network of neurons..."
"What is DNA?","DNA, or deoxyribonucleic acid, is a molecule that contains the genetic instructions for life..."
"How do plants make food?","Plants make food through a process called photosynthesis..."
"What is the theory of evolution?","The theory of evolution proposes that species change over time..."
"What causes tides?","Tides are primarily caused by the gravitational pull of the Moon and the Sun..."
"How does a camera work?","Cameras capture images by focusing light onto a photosensitive surface..."
"What is a black hole?","A black hole is a region of space where gravity is so strong that nothing, not even light, can escape..."
"How do vaccines work?","Vaccines work by stimulating the immune system to produce antibodies..."
"What is the difference between weather and climate?","Weather refers to short-term atmospheric conditions..."
"How do touchscreens work?","Touchscreens detect touch input using electrical signals..."
"What causes thunderstorms?","Thunderstorms are caused by the rapid upward movement of warm, moist air..."
"What is the purpose of the circulatory system?","The circulatory system transports oxygen, nutrients, and hormones throughout the body..."
"What are the properties of acids and bases?","Acids have a sour taste and turn blue litmus paper red..."
"How do cell phones work?","Cell phones work by transmitting and receiving radio signals..."
"What causes the seasons?","The tilt of the Earth's axis causes the seasons..."
"How do solar panels work?","Solar panels convert sunlight into electricity using photovoltaic cells..."
"What is the difference between a hurricane and a tornado?","Hurricanes are large, rotating storms that form over warm ocean waters..."
"How does the human digestive system work?","The human digestive system breaks down food into nutrients..."
"What is a tsunami?","A tsunami is a series of large ocean waves caused by underwater earthquakes or volcanic eruptions..."
"What are the different types of clouds?","Clouds are categorized based on their shape and altitude..."
"What is the structure of an atom?","Atoms consist of a nucleus surrounded by electrons..."
"How does the immune system work?","The immune system protects the body from harmful pathogens..."
"What is the Big Bang theory?","The Big Bang theory is the prevailing cosmological model..."
"What causes the phases of the moon?","The phases of the moon are caused by the relative positions of the Earth, moon, and Sun..."
"How does a refrigerator work?","Refrigerators work by removing heat from the interior..."
"What is the water table?","The water table is the level below which the ground is saturated with water..."
"What is the difference between speed and velocity?","Speed is a scalar quantity that measures how fast an object is moving..."
"How do 3D printers work?","3D printers create three-dimensional objects by laying down successive layers of material..."
"What are the primary colors of light?","The primary colors of light are red, green, and blue..."
"What is the difference between a comet and an asteroid?","Comets are icy bodies that orbit the Sun..."
"How do magnets work?","Magnets produce magnetic fields that exert forces on other magnets and magnetic materials..."
"What causes lightning?","Lightning is caused by the buildup of static electricity in clouds..."
"What is the difference between erosion and weathering?","Erosion is the process of transporting weathered material..."
"How does the human respiratory system work?","The human respiratory system delivers oxygen to the body and removes carbon dioxide..."
"What is the difference between a herbivore and a carnivore?","Herbivores primarily eat plants..."
"How do elevators work?","Elevators use electric motors to move between floors..."
"What causes the color of the sky?","The color of the sky is primarily due to scattering of sunlight by particles in the atmosphere..."
"What is the difference between mass and weight?","Mass is a measure of the amount of matter in an object..."
"How do lasers work?","Lasers emit coherent light through a process called stimulated emission..."
"What is the difference between an ecosystem and a habitat?","An ecosystem consists of all the living and nonliving components of a particular environment..."
"How do digital cameras work?","Digital cameras capture and store images as digital data..."
"What is the difference between weathering and erosion?","Weathering is the process of breaking down rocks..."
"How does a nuclear reactor work?","Nuclear reactors produce electricity by harnessing the heat released from nuclear reactions..."
"What causes the Northern Lights?","The Northern Lights, or auroras, are caused by charged particles from the Sun interacting with Earth's magnetic field..."
"How does a battery work?","Batteries convert chemical energy into electrical energy..."
"What are the layers of the Earth's atmosphere?","The Earth's atmosphere consists of several layers..."
"How do submarines work?","Submarines are underwater vessels that use ballast tanks to control buoyancy..."
"What is the difference between a vertebrate and an invertebrate?","Vertebrates have a backbone or spinal column..."
"How do rainbows form?","Rainbows form when sunlight is refracted, reflected, and dispersed by water droplets in the atmosphere..."
"What causes the Coriolis effect?","The Coriolis effect is caused by the rotation of the Earth..."
"How does a telescope work?","Telescopes gather and focus light to produce magnified images of distant objects..."
"What is the difference between a solution and a suspension?","A solution is a homogeneous mixture..."
"How does a car engine work?","Car engines convert chemical energy from fuel into mechanical energy..."
"What are the different types of energy?","Energy exists in various forms, including kinetic, potential, thermal, and electromagnetic..."
"How do radios work?","Radios receive and transmit radio waves to communicate over long distances..."
"What is the difference between an element and a compound?","An element consists of atoms with the same number of protons..."
"How does a wind turbine work?","Wind turbines convert kinetic energy from the wind into electrical energy..."
"What causes the formation of clouds?","Clouds form when warm, moist air rises and cools..."
"How does a toilet work?","Toilets use water and gravity to remove waste from the bowl..."
"What is the difference between velocity and acceleration?","Velocity is a vector quantity that includes both speed and direction..."

Resulting error;

Running the config:
tune run lora_finetune_single_device --config custom_config_qlora.yaml

0it [00:00, ?it/s]Traceback (most recent call last):
  File "/opt/conda/bin/tune", line 8, in <module>
    sys.exit(main())
  File "/opt/conda/lib/python3.10/site-packages/torchtune/_cli/tune.py", line 49, in main
    parser.run(args)
  File "/opt/conda/lib/python3.10/site-packages/torchtune/_cli/tune.py", line 43, in run
    args.func(args)
  File "/opt/conda/lib/python3.10/site-packages/torchtune/_cli/run.py", line 179, in _run_cmd
    self._run_single_device(args)
  File "/opt/conda/lib/python3.10/site-packages/torchtune/_cli/run.py", line 93, in _run_single_device
    runpy.run_path(str(args.recipe), run_name="__main__")
  File "/opt/conda/lib/python3.10/runpy.py", line 289, in run_path
    return _run_module_code(code, init_globals, run_name,
  File "/opt/conda/lib/python3.10/runpy.py", line 96, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File "/opt/conda/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/opt/conda/lib/python3.10/site-packages/recipes/lora_finetune_single_device.py", line 550, in <module>
    sys.exit(recipe_main())
  File "/opt/conda/lib/python3.10/site-packages/torchtune/config/_parse.py", line 50, in wrapper
    sys.exit(recipe_main(conf))
  File "/opt/conda/lib/python3.10/site-packages/recipes/lora_finetune_single_device.py", line 545, in recipe_main
    recipe.train()
  File "/opt/conda/lib/python3.10/site-packages/recipes/lora_finetune_single_device.py", line 450, in train
    for idx, batch in enumerate(self._dataloader):
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 631, in __next__
    data = self._next_data()
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 675, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/opt/conda/lib/python3.10/site-packages/torchtune/datasets/_chat.py", line 89, in __getitem__
    sample = self._data[index]
  File "/opt/conda/lib/python3.10/site-packages/datasets/dataset_dict.py", line 81, in __getitem__
    raise KeyError(
KeyError: "Invalid key: 0. Please first select a split. For example: `my_dataset_dictionary['train'][0]`. Available splits: ['train']"
0it [00:00, ?it/s]

Changing to split='train'

Changing my_module.py:

def custom_dataset(
    *,
    tokenizer: Tokenizer,
    max_seq_len: int = 2048,  # You can expose this if you want to experiment
) -> ChatDataset:

    return ChatDataset(
        tokenizer=tokenizer,
        # For local csv files, we specify "csv" as the source, just like in
        # load_dataset
        source="csv",
        convert_to_messages=message_converter,
        split="train",
        # Llama3 does not need a chat format
        chat_format=None,
        max_seq_len=max_seq_len,
        # To load a local file we specify it as data_files just like in
        # load_dataset
        data_files="your_file.csv",
    )

Second resulting error:

Traceback (most recent call last):
  File "/opt/conda/bin/tune", line 8, in <module>
    sys.exit(main())
  File "/opt/conda/lib/python3.10/site-packages/torchtune/_cli/tune.py", line 49, in main
    parser.run(args)
  File "/opt/conda/lib/python3.10/site-packages/torchtune/_cli/tune.py", line 43, in run
    args.func(args)
  File "/opt/conda/lib/python3.10/site-packages/torchtune/_cli/run.py", line 179, in _run_cmd
    self._run_single_device(args)
  File "/opt/conda/lib/python3.10/site-packages/torchtune/_cli/run.py", line 93, in _run_single_device
    runpy.run_path(str(args.recipe), run_name="__main__")
  File "/opt/conda/lib/python3.10/runpy.py", line 289, in run_path
    return _run_module_code(code, init_globals, run_name,
  File "/opt/conda/lib/python3.10/runpy.py", line 96, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File "/opt/conda/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/opt/conda/lib/python3.10/site-packages/recipes/lora_finetune_single_device.py", line 550, in <module>
    sys.exit(recipe_main())
  File "/opt/conda/lib/python3.10/site-packages/torchtune/config/_parse.py", line 50, in wrapper
    sys.exit(recipe_main(conf))
  File "/opt/conda/lib/python3.10/site-packages/recipes/lora_finetune_single_device.py", line 545, in recipe_main
    recipe.train()
  File "/opt/conda/lib/python3.10/site-packages/recipes/lora_finetune_single_device.py", line 450, in train
    for idx, batch in enumerate(self._dataloader):
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 631, in __next__
    data = self._next_data()
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 675, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/opt/conda/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/opt/conda/lib/python3.10/site-packages/torchtune/datasets/_chat.py", line 90, in __getitem__
    return self._prepare_sample(sample)
  File "/opt/conda/lib/python3.10/site-packages/torchtune/datasets/_chat.py", line 93, in _prepare_sample
    messages = self._convert_to_messages(sample, self.train_on_input)
TypeError: message_converter() takes 1 positional argument but 2 were given

Adding train_on_input to message_converter

Changing my_module.py:

def message_converter(sample: Mapping[str, Any], train_on_input: bool) -> List[Message]:
    input_msg = sample["input"]
    output_msg = sample["output"]

    user_message = Message(
        role="user",
        content=input_msg,
        masked=True,  # Mask if not training on prompt
    )
    assistant_message = Message(
        role="assistant",
        content=output_msg,
        masked=False,
    )
    # A single turn conversation
    messages = [user_message, assistant_message]

    return messages

Final result:

1|2|Loss: 2.3245208263397217: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:36<00:00, 18.01s/it]

INFO:torchtune.utils.logging:Model checkpoint of size 16.06 GB saved to /tmp/Meta-Llama-3-8B-Instruct/meta_model_0.pt
INFO:torchtune.utils.logging:Adapter checkpoint of size 0.04 GB saved to /tmp/Meta-Llama-3-8B-Instruct/adapter_0.pt
1|2|Loss: 2.3245208263397217: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [02:14<00:00, 67.21s/it]

It's working 🎉

@christobill christobill force-pushed the chore/1002-chat-tutorial-update branch from e25ce58 to 9d848ab Compare May 20, 2024 10:02
@joecummings joecummings requested a review from RdoubleA May 21, 2024 13:54
Copy link
Contributor

@RdoubleA RdoubleA left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Apologies for the delay on this. I was able to repro your setup and get it training. I left some comments

@@ -19,6 +19,7 @@ custom chat dataset for fine-tuning Llama3.
* Be familiar with :ref:`configuring datasets<dataset_tutorial_label>`
* Know how to :ref:`download Llama3 weights <llama3_label>`

Note: this tutorial works with a version of torchtune > 0.1.1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

Suggested change
Note: this tutorial works with a version of torchtune > 0.1.1
Note: this tutorial requires torchtune >= 0.1.1

@@ -252,7 +253,7 @@ the Message dataclass.

.. code-block:: python

def message_converter(sample: Mapping[str, Any]) -> List[Message]:
def message_converter(sample: Mapping[str, Any], train_on_input: bool) -> List[Message]:
Copy link
Contributor

@RdoubleA RdoubleA May 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you should incorporate the parameter in the logic below:

user_message = Message(
    role="user",
    content=input_msg,
    masked=not train_on_input,  # Mask if not training on prompt
)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done ✔️

Copy link
Contributor

@RdoubleA RdoubleA left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

discussed offline on discord, I was doing your repro incorrectly. I didn't realize I wasn't using your config when I launched the run because I just copied it from the comments on the top of the file... anyhow, sorry about that. your changes make sense. I'll approve this, but if you could add those details we discussed that would be much appreciated :)

@@ -294,6 +295,7 @@ object.
# For local csv files, we specify "csv" as the source, just like in
# load_dataset
source="csv",
split="train",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also would be good to add a comment above this line saying specifying the default split of "train" is required for local files

also if you don't mind split="train" to the examples here as well:

Local and remote datasets

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done ✔️

@christobill christobill force-pushed the chore/1002-chat-tutorial-update branch 2 times, most recently from cb01a85 to 9134cf0 Compare May 25, 2024 11:29
@christobill christobill force-pushed the chore/1002-chat-tutorial-update branch from 9134cf0 to 964315a Compare May 25, 2024 12:00
@joecummings joecummings merged commit ad45c28 into pytorch:main May 28, 2024
29 checks passed
@christobill christobill deleted the chore/1002-chat-tutorial-update branch May 28, 2024 13:19
weifengpy pushed a commit to weifengpy/torchtune that referenced this pull request Jun 4, 2024
maximegmd pushed a commit to maximegmd/torchtune that referenced this pull request Jul 13, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants