Move generation out of utils (#1513)
SalmanMohammadi authored Sep 6, 2024
1 parent 277fbf8 commit 31a95a9
Showing 12 changed files with 39 additions and 23 deletions.
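In short: the public generation utilities move out of torchtune.utils into a new top-level torchtune.generation package, and every call site is updated to match. The gist for downstream code (both import paths are confirmed by the diffs below):

# Before this commit:
# from torchtune.utils import generate, generate_next_token

# After this commit:
from torchtune.generation import generate, generate_next_token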
13 changes: 13 additions & 0 deletions docs/source/api_ref_generation.rst
@@ -0,0 +1,13 @@
+.. _generation:
+
+====================
+torchtune.generation
+====================
+
+.. currentmodule:: torchtune.generation
+
+.. autosummary::
+    :toctree: generated/
+    :nosignatures:
+
+    generate
1 change: 0 additions & 1 deletion docs/source/api_ref_utilities.rst
@@ -16,5 +16,4 @@ Miscellaneous

get_device
get_logger
-generate
torch_version_ge
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -143,6 +143,7 @@ torchtune tutorials.
api_ref_config
api_ref_data
api_ref_datasets
+api_ref_generation
api_ref_models
api_ref_modules
api_ref_training
5 changes: 2 additions & 3 deletions recipes/eleuther_eval.py
@@ -13,13 +13,12 @@
from omegaconf import DictConfig

from torch import nn
-from torchtune import config, training, utils
+from torchtune import config, generation, training, utils
from torchtune.data import left_pad_sequence
from torchtune.modules import TransformerDecoder
from torchtune.modules.tokenizers import ModelTokenizer
from torchtune.recipe_interfaces import EvalRecipeInterface
-

logger = utils.get_logger("DEBUG")

try:
@@ -155,7 +154,7 @@ def _model_generate(
"``do_sample`` for generation tasks is not supported yet in torchtune."
)

-toks = utils.generate(
+toks = generation.generate(
self._model,
context,
max_generated_tokens=self.max_gen_toks,
8 changes: 4 additions & 4 deletions recipes/generate.py
@@ -12,7 +12,7 @@
from omegaconf import DictConfig
from torch import nn

-from torchtune import config, training, utils
+from torchtune import config, generation, training, utils
from torchtune.config._utils import _get_component_from_path
from torchtune.data import ChatFormat, InstructTemplate, Message

@@ -147,10 +147,10 @@ def generate(self, cfg: DictConfig):
if self._quantization_mode is not None:
logger.info("Starting compilation to improve generation performance ...")
custom_generate_next_token = torch.compile(
-utils.generate_next_token, mode="max-autotune", fullgraph=True
+generation.generate_next_token, mode="max-autotune", fullgraph=True
)
t0 = time.perf_counter()
-_ = utils.generate(
+_ = generation.generate(
model=self._model,
prompt=prompt,
max_generated_tokens=2,
@@ -163,7 +163,7 @@
logger.info(f"Warmup run for quantized model takes: {t:.02f} sec")

t0 = time.perf_counter()
-generated_tokens = utils.generate(
+generated_tokens = generation.generate(
model=self._model,
prompt=prompt,
max_generated_tokens=cfg.max_new_tokens,
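The pattern in this recipe (compile generate_next_token once, then run a throwaway two-token generation so compilation happens before the timed run) is worth spelling out. A minimal sketch under stated assumptions: model and prompt are the recipe's already-loaded decoder and token tensor, and generate accepts the compiled function via the custom_generate_next_token keyword, as the recipe's call suggests:

import time

import torch
from torchtune import generation

# Assumptions for this sketch: `model` is a loaded TransformerDecoder and
# `prompt` is an integer token tensor; both come from the recipe setup.
custom_generate_next_token = torch.compile(
    generation.generate_next_token, mode="max-autotune", fullgraph=True
)

# Warmup: a tiny 2-token generation triggers compilation up front, so a
# later timed run measures generation speed rather than compiler latency.
t0 = time.perf_counter()
_ = generation.generate(
    model=model,
    prompt=prompt,
    max_generated_tokens=2,
    custom_generate_next_token=custom_generate_next_token,
)
print(f"Warmup run took {time.perf_counter() - t0:.02f} sec")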
@@ -8,10 +8,9 @@
import torch

from tests.test_utils import fixed_init_model
+from torchtune.generation._generation import generate, sample

-from torchtune import utils
from torchtune.models.llama2 import llama2
-from torchtune.utils._generation import sample


class TestTextGenerate:
@@ -119,7 +118,7 @@ def test_reproducibility(self, request, model1, model2, prompt):
top_k = 100

torch.manual_seed(42)
-outputs_first = utils.generate(
+outputs_first = generate(
model=model1,
prompt=prompt,
max_generated_tokens=10,
@@ -128,7 +127,7 @@
)

torch.manual_seed(42)
-outputs_second = utils.generate(
+outputs_second = generate(
model=model2,
prompt=prompt,
max_generated_tokens=10,
@@ -145,7 +144,7 @@ def test_batched_generate(self, generation_model_batched, prompt_tokens_batched)

torch.manual_seed(42)

-output = utils.generate(
+output = generate(
model=generation_model_batched,
prompt=prompt_tokens_batched,
max_generated_tokens=10,
@@ -215,7 +214,7 @@ def test_stop_tokens(self, generation_model, prompt_tokens):

torch.manual_seed(42)

-outputs = utils.generate(
+outputs = generate(
model=generation_model,
prompt=prompt_tokens,
max_generated_tokens=10,
@@ -242,7 +241,7 @@ def test_stop_tokens_batched(self, generation_model_batched, prompt_tokens_batch

torch.manual_seed(42)

-outputs = utils.generate(
+outputs = generate(
model=generation_model_batched,
prompt=prompt_tokens_batched,
max_generated_tokens=10,
@@ -275,7 +274,7 @@ def test_stop_tokens_batched_uneven_stopping(

torch.manual_seed(42)

-outputs = utils.generate(
+outputs = generate(
model=generation_model_batched,
prompt=prompt_tokens_batched,
max_generated_tokens=10,
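A note on the test pattern above: reseeding with torch.manual_seed(42) immediately before each call is what makes sampled generation comparable across two identically initialized models. A condensed sketch, assuming the model1, model2, prompt, and top_k fixtures from the test file:

import torch
from torchtune.generation import generate

top_k = 100  # matches the fixture above

torch.manual_seed(42)
outputs_first = generate(
    model=model1, prompt=prompt, max_generated_tokens=10, top_k=top_k
)

torch.manual_seed(42)
outputs_second = generate(
    model=model2, prompt=prompt, max_generated_tokens=10, top_k=top_k
)

# With the same seed and identically initialized models, the two outputs
# should match token-for-token.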
2 changes: 1 addition & 1 deletion tests/torchtune/modules/rlhf/test_generation.py
@@ -8,9 +8,9 @@

import torch
from tests.test_utils import fixed_init_model
+from torchtune.generation._generation import sample
from torchtune.models.llama2 import llama2
from torchtune.modules import rlhf
-from torchtune.utils._generation import sample


class TestGenerateNextTokenWithLogits:
4 changes: 2 additions & 2 deletions torchtune/__init__.py
@@ -23,6 +23,6 @@
"""
) from e

-from torchtune import datasets, models, modules, utils
+from torchtune import datasets, generation, models, modules, utils

-__all__ = [datasets, models, modules, utils]
+__all__ = [datasets, models, modules, utils, generation]
9 changes: 9 additions & 0 deletions torchtune/generation/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from ._generation import generate, generate_next_token
+
+__all__ = ["generate", "generate_next_token"]
File renamed without changes.
1 change: 0 additions & 1 deletion torchtune/training/__init__.py
@@ -3,7 +3,6 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
-
from torchtune.training._distributed import (
contains_fsdp,
FSDPPolicyType,
3 changes: 0 additions & 3 deletions torchtune/utils/__init__.py
@@ -5,7 +5,6 @@
# LICENSE file in the root directory of this source tree.

from ._device import get_device
-from ._generation import generate, generate_next_token

from ._version import torch_version_ge
from .logging import get_logger
@@ -14,6 +13,4 @@
"get_device",
"get_logger",
"torch_version_ge",
"generate",
"generate_next_token",
]
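Because torchtune.utils no longer re-exports these helpers, downstream code importing generate from utils will now raise an ImportError. One hedged way for a consumer to stay compatible across torchtune versions while upgrading:

try:
    # torchtune with this commit and later
    from torchtune.generation import generate, generate_next_token
except ImportError:
    # older torchtune releases exported these from utils
    from torchtune.utils import generate, generate_next_token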
