Move generation out of utils #1513

Merged 6 commits on Sep 6, 2024
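For downstream code, the practical effect of this PR is an import-path change: `generate` and `generate_next_token` move from `torchtune.utils` to a new `torchtune.generation` namespace. A minimal sketch of the new call path, assuming the keyword arguments visible at the call sites in the diffs below; the toy model sizes and prompt are arbitrary placeholders, not values from this PR:

```python
import torch
from torchtune import generation
from torchtune.models.llama2 import llama2

# Toy decoder purely for illustration; these sizes are arbitrary assumptions.
model = llama2(
    vocab_size=100,
    num_layers=2,
    num_heads=4,
    num_kv_heads=4,
    embed_dim=64,
    max_seq_len=128,
)
prompt = torch.arange(10)  # toy prompt token ids

# Before this PR: tokens = utils.generate(...)
# After this PR the same helper lives under torchtune.generation:
tokens = generation.generate(
    model=model,
    prompt=prompt,
    max_generated_tokens=10,
)
```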
13 changes: 13 additions & 0 deletions docs/source/api_ref_generation.rst
@@ -0,0 +1,13 @@
.. _generation:

====================
torchtune.generation
====================

.. currentmodule:: torchtune.generation

.. autosummary::
:toctree: generated/
:nosignatures:

generate
1 change: 0 additions & 1 deletion docs/source/api_ref_utilities.rst
@@ -16,5 +16,4 @@ Miscellaneous

get_device
get_logger
generate
torch_version_ge
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -143,6 +143,7 @@ torchtune tutorials.
api_ref_config
api_ref_data
api_ref_datasets
api_ref_generation
api_ref_models
api_ref_modules
api_ref_training
5 changes: 2 additions & 3 deletions recipes/eleuther_eval.py
@@ -13,13 +13,12 @@
from omegaconf import DictConfig

from torch import nn
from torchtune import config, training, utils
from torchtune import config, generation, training, utils
from torchtune.data import left_pad_sequence
from torchtune.modules import TransformerDecoder
from torchtune.modules.tokenizers import ModelTokenizer
from torchtune.recipe_interfaces import EvalRecipeInterface


logger = utils.get_logger("DEBUG")

try:
@@ -155,7 +154,7 @@ def _model_generate(
"``do_sample`` for generation tasks is not supported yet in torchtune."
)

toks = utils.generate(
toks = generation.generate(
self._model,
context,
max_generated_tokens=self.max_gen_toks,
8 changes: 4 additions & 4 deletions recipes/generate.py
@@ -12,7 +12,7 @@
from omegaconf import DictConfig
from torch import nn

from torchtune import config, training, utils
from torchtune import config, generation, training, utils
from torchtune.config._utils import _get_component_from_path
from torchtune.data import ChatFormat, InstructTemplate, Message

@@ -147,10 +147,10 @@ def generate(self, cfg: DictConfig):
if self._quantization_mode is not None:
logger.info("Starting compilation to improve generation performance ...")
custom_generate_next_token = torch.compile(
utils.generate_next_token, mode="max-autotune", fullgraph=True
generation.generate_next_token, mode="max-autotune", fullgraph=True
)
t0 = time.perf_counter()
_ = utils.generate(
_ = generation.generate(
model=self._model,
prompt=prompt,
max_generated_tokens=2,
@@ -163,7 +163,7 @@
logger.info(f"Warmup run for quantized model takes: {t:.02f} sec")

t0 = time.perf_counter()
generated_tokens = utils.generate(
generated_tokens = generate(
model=self._model,
prompt=prompt,
max_generated_tokens=cfg.max_new_tokens,
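The quantized branch shown above compiles only the per-token step (`generation.generate_next_token`) with `max-autotune` and then runs a tiny two-token generation so the compilation cost is paid during warmup rather than in the timed run. A hedged sketch of that pattern with the recipe's logging stripped out; the `custom_generate_next_token` keyword is inferred from the surrounding call site, and `model`/`prompt` are placeholders as in the earlier sketch:

```python
import time

import torch
from torchtune import generation

# Compile the hot per-token sampling step once...
custom_generate_next_token = torch.compile(
    generation.generate_next_token, mode="max-autotune", fullgraph=True
)

# ...then trigger compilation with a short warmup generation.
t0 = time.perf_counter()
_ = generation.generate(
    model=model,    # placeholder: the recipe uses its quantized self._model
    prompt=prompt,  # placeholder prompt token tensor
    max_generated_tokens=2,
    custom_generate_next_token=custom_generate_next_token,  # assumed keyword
)
print(f"Warmup run takes: {time.perf_counter() - t0:.02f} sec")
```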
2 changes: 1 addition & 1 deletion tests/torchtune/modules/rlhf/test_generation.py
@@ -8,9 +8,9 @@

import torch
from tests.test_utils import fixed_init_model
from torchtune.generation._generation import sample
from torchtune.models.llama2 import llama2
from torchtune.modules import rlhf
from torchtune.utils._generation import sample


class TestGenerateNextTokenWithLogits:
Contributor: just need to move this file to tests/torchtune/generation :)

Collaborator (author): sorry multitasking
@@ -8,10 +8,9 @@
import torch

from tests.test_utils import fixed_init_model
from torchtune.generation._generation import generate, sample

from torchtune import utils
from torchtune.models.llama2 import llama2
from torchtune.utils._generation import sample


class TestTextGenerate:
@@ -119,7 +118,7 @@ def test_reproducibility(self, request, model1, model2, prompt):
top_k = 100

torch.manual_seed(42)
outputs_first = utils.generate(
outputs_first = generate(
model=model1,
prompt=prompt,
max_generated_tokens=10,
@@ -128,7 +127,7 @@ def test_reproducibility(self, request, model1, model2, prompt):
)

torch.manual_seed(42)
outputs_second = utils.generate(
outputs_second = generate(
model=model2,
prompt=prompt,
max_generated_tokens=10,
@@ -145,7 +144,7 @@ def test_batched_generate(self, generation_model_batched, prompt_tokens_batched)

torch.manual_seed(42)

output = utils.generate(
output = generate(
model=generation_model_batched,
prompt=prompt_tokens_batched,
max_generated_tokens=10,
@@ -215,7 +214,7 @@ def test_stop_tokens(self, generation_model, prompt_tokens):

torch.manual_seed(42)

outputs = utils.generate(
outputs = generate(
model=generation_model,
prompt=prompt_tokens,
max_generated_tokens=10,
@@ -242,7 +241,7 @@ def test_stop_tokens_batched(self, generation_model_batched, prompt_tokens_batch

torch.manual_seed(42)

outputs = utils.generate(
outputs = generate(
model=generation_model_batched,
prompt=prompt_tokens_batched,
max_generated_tokens=10,
@@ -275,7 +274,7 @@ def test_stop_tokens_batched_uneven_stopping(

torch.manual_seed(42)

outputs = utils.generate(
outputs = generate(
model=generation_model_batched,
prompt=prompt_tokens_batched,
max_generated_tokens=10,
4 changes: 2 additions & 2 deletions torchtune/__init__.py
@@ -23,6 +23,6 @@
"""
) from e

from torchtune import datasets, models, modules, utils
from torchtune import datasets, generation, models, modules, utils

__all__ = [datasets, models, modules, utils]
__all__ = [datasets, models, modules, utils, generation]
9 changes: 9 additions & 0 deletions torchtune/generation/__init__.py
@@ -0,0 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from ._generation import generate, generate_next_token

__all__ = ["generate", "generate_next_token"]
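The `__init__.py` above re-exports the two public entry points from the private `_generation` module. The updated tests later in this diff import from `torchtune.generation._generation` directly, while downstream code can rely on the shorter public path; a small sketch of the two import forms that this PR makes available (nothing here is called, just imported):

```python
# Public path, backed by the re-export in torchtune/generation/__init__.py
from torchtune.generation import generate, generate_next_token

# Private path, as used by the updated tests in this PR; it resolves today
# but points at an internal module that is not part of __all__
from torchtune.generation._generation import generate, sample
```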
File renamed without changes.
1 change: 0 additions & 1 deletion torchtune/training/__init__.py
@@ -3,7 +3,6 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtune.training._distributed import (
contains_fsdp,
FSDPPolicyType,
3 changes: 0 additions & 3 deletions torchtune/utils/__init__.py
@@ -5,7 +5,6 @@
# LICENSE file in the root directory of this source tree.

from ._device import get_device
from ._generation import generate, generate_next_token

from ._version import torch_version_ge
from .logging import get_logger
@@ -14,6 +13,4 @@
"get_device",
"get_logger",
"torch_version_ge",
"generate",
"generate_next_token",
]