diff --git a/docs/source/api_ref_generation.rst b/docs/source/api_ref_generation.rst
new file mode 100644
index 0000000000..f9fd3c3b42
--- /dev/null
+++ b/docs/source/api_ref_generation.rst
@@ -0,0 +1,13 @@
+.. _generation:
+
+====================
+torchtune.generation
+====================
+
+.. currentmodule:: torchtune.generation
+
+.. autosummary::
+    :toctree: generated/
+    :nosignatures:
+
+    generate
diff --git a/docs/source/api_ref_utilities.rst b/docs/source/api_ref_utilities.rst
index 6729c85500..35b97d8bfb 100644
--- a/docs/source/api_ref_utilities.rst
+++ b/docs/source/api_ref_utilities.rst
@@ -16,5 +16,4 @@ Miscellaneous

     get_device
     get_logger
-    generate
     torch_version_ge
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 6e3304cee4..248fa8541d 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -143,6 +143,7 @@ torchtune tutorials.
    api_ref_config
    api_ref_data
    api_ref_datasets
+   api_ref_generation
    api_ref_models
    api_ref_modules
    api_ref_training
diff --git a/recipes/eleuther_eval.py b/recipes/eleuther_eval.py
index ffb5e37551..4b80e132b0 100644
--- a/recipes/eleuther_eval.py
+++ b/recipes/eleuther_eval.py
@@ -13,13 +13,12 @@
 from omegaconf import DictConfig
 from torch import nn

-from torchtune import config, training, utils
+from torchtune import config, generation, training, utils
 from torchtune.data import left_pad_sequence
 from torchtune.modules import TransformerDecoder
 from torchtune.modules.tokenizers import ModelTokenizer
 from torchtune.recipe_interfaces import EvalRecipeInterface

-
 logger = utils.get_logger("DEBUG")

 try:
@@ -155,7 +154,7 @@ def _model_generate(
                 "``do_sample`` for generation tasks is not supported yet in torchtune."
             )

-        toks = utils.generate(
+        toks = generation.generate(
             self._model,
             context,
             max_generated_tokens=self.max_gen_toks,
diff --git a/recipes/generate.py b/recipes/generate.py
index 31ddd9dee4..9b235813ec 100644
--- a/recipes/generate.py
+++ b/recipes/generate.py
@@ -12,7 +12,7 @@
 from omegaconf import DictConfig
 from torch import nn

-from torchtune import config, training, utils
+from torchtune import config, generation, training, utils
 from torchtune.config._utils import _get_component_from_path
 from torchtune.data import ChatFormat, InstructTemplate, Message

@@ -147,10 +147,10 @@ def generate(self, cfg: DictConfig):
         if self._quantization_mode is not None:
             logger.info("Starting compilation to improve generation performance ...")
             custom_generate_next_token = torch.compile(
-                utils.generate_next_token, mode="max-autotune", fullgraph=True
+                generation.generate_next_token, mode="max-autotune", fullgraph=True
             )
             t0 = time.perf_counter()
-            _ = utils.generate(
+            _ = generation.generate(
                 model=self._model,
                 prompt=prompt,
                 max_generated_tokens=2,
@@ -163,7 +163,7 @@ def generate(self, cfg: DictConfig):
             logger.info(f"Warmup run for quantized model takes: {t:.02f} sec")

         t0 = time.perf_counter()
-        generated_tokens = utils.generate(
+        generated_tokens = generation.generate(
             model=self._model,
             prompt=prompt,
             max_generated_tokens=cfg.max_new_tokens,
diff --git a/tests/torchtune/utils/test_generation.py b/tests/torchtune/generation/test_generation.py
similarity index 96%
rename from tests/torchtune/utils/test_generation.py
rename to tests/torchtune/generation/test_generation.py
index 15c4a336fa..671308cf39 100644
--- a/tests/torchtune/utils/test_generation.py
+++ b/tests/torchtune/generation/test_generation.py
@@ -8,10 +8,9 @@
 import torch

 from tests.test_utils import fixed_init_model
+from torchtune.generation._generation import generate, sample
-from torchtune import utils
 from torchtune.models.llama2 import llama2
-from torchtune.utils._generation import sample


 class TestTextGenerate:
@@ -119,7 +118,7 @@ def test_reproducibility(self, request, model1, model2, prompt):
         top_k = 100

         torch.manual_seed(42)
-        outputs_first = utils.generate(
+        outputs_first = generate(
             model=model1,
             prompt=prompt,
             max_generated_tokens=10,
@@ -128,7 +127,7 @@ def test_reproducibility(self, request, model1, model2, prompt):
         )

         torch.manual_seed(42)
-        outputs_second = utils.generate(
+        outputs_second = generate(
             model=model2,
             prompt=prompt,
             max_generated_tokens=10,
@@ -145,7 +144,7 @@ def test_batched_generate(self, generation_model_batched, prompt_tokens_batched):

         torch.manual_seed(42)

-        output = utils.generate(
+        output = generate(
             model=generation_model_batched,
             prompt=prompt_tokens_batched,
             max_generated_tokens=10,
@@ -215,7 +214,7 @@ def test_stop_tokens(self, generation_model, prompt_tokens):

         torch.manual_seed(42)

-        outputs = utils.generate(
+        outputs = generate(
             model=generation_model,
             prompt=prompt_tokens,
             max_generated_tokens=10,
@@ -242,7 +241,7 @@ def test_stop_tokens_batched(self, generation_model_batched, prompt_tokens_batched):

         torch.manual_seed(42)

-        outputs = utils.generate(
+        outputs = generate(
             model=generation_model_batched,
             prompt=prompt_tokens_batched,
             max_generated_tokens=10,
@@ -275,7 +274,7 @@ def test_stop_tokens_batched_uneven_stopping(

         torch.manual_seed(42)

-        outputs = utils.generate(
+        outputs = generate(
             model=generation_model_batched,
             prompt=prompt_tokens_batched,
             max_generated_tokens=10,
diff --git a/tests/torchtune/modules/rlhf/test_generation.py b/tests/torchtune/modules/rlhf/test_generation.py
index 2613a4f6b8..bfe8237c0c 100644
--- a/tests/torchtune/modules/rlhf/test_generation.py
+++ b/tests/torchtune/modules/rlhf/test_generation.py
@@ -8,9 +8,9 @@
 import torch

 from tests.test_utils import fixed_init_model
+from torchtune.generation._generation import sample
 from torchtune.models.llama2 import llama2
 from torchtune.modules import rlhf
-from torchtune.utils._generation import sample


 class TestGenerateNextTokenWithLogits:
diff --git a/torchtune/__init__.py b/torchtune/__init__.py
index 0b0511134f..a57f08eb64 100644
--- a/torchtune/__init__.py
+++ b/torchtune/__init__.py
@@ -23,6 +23,6 @@
     """
     ) from e

-from torchtune import datasets, models, modules, utils
+from torchtune import datasets, generation, models, modules, utils

-__all__ = [datasets, models, modules, utils]
+__all__ = [datasets, models, modules, utils, generation]
diff --git a/torchtune/generation/__init__.py b/torchtune/generation/__init__.py
new file mode 100644
index 0000000000..3111061df7
--- /dev/null
+++ b/torchtune/generation/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from ._generation import generate, generate_next_token
+
+__all__ = ["generate", "generate_next_token"]
diff --git a/torchtune/utils/_generation.py b/torchtune/generation/_generation.py
similarity index 100%
rename from torchtune/utils/_generation.py
rename to torchtune/generation/_generation.py
diff --git a/torchtune/training/__init__.py b/torchtune/training/__init__.py
index a35e8ca89e..8d3aa78877 100644
--- a/torchtune/training/__init__.py
+++ b/torchtune/training/__init__.py
@@ -3,7 +3,6 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
-
 from torchtune.training._distributed import (
     contains_fsdp,
     FSDPPolicyType,
diff --git a/torchtune/utils/__init__.py b/torchtune/utils/__init__.py
index a4040dbdc8..46a64d4648 100644
--- a/torchtune/utils/__init__.py
+++ b/torchtune/utils/__init__.py
@@ -5,7 +5,6 @@
 # LICENSE file in the root directory of this source tree.

 from ._device import get_device
-from ._generation import generate, generate_next_token
 from ._version import torch_version_ge
 from .logging import get_logger

@@ -14,6 +13,4 @@
     "get_device",
     "get_logger",
     "torch_version_ge",
-    "generate",
-    "generate_next_token",
 ]
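
Usage after this change: generate and generate_next_token now live in torchtune.generation rather than torchtune.utils. The sketch below is a minimal illustration of the new import path, mirroring the call pattern in the updated tests; the llama2 builder arguments and the dummy prompt are illustrative placeholders, not values taken from this diff.

    import torch

    from torchtune import generation
    from torchtune.models.llama2 import llama2

    # Tiny decoder purely for illustration; any TransformerDecoder works.
    model = llama2(
        vocab_size=32_000,
        num_layers=2,
        num_heads=4,
        num_kv_heads=4,
        embed_dim=128,
        max_seq_len=256,
    )
    model.eval()

    # Dummy prompt token ids with shape [batch_size, seq_len];
    # real callers would produce these with a tokenizer.
    prompt = torch.randint(0, 32_000, (1, 8))

    # Previously spelled utils.generate(...); only the namespace changed.
    with torch.no_grad():
        tokens = generation.generate(
            model=model,
            prompt=prompt,
            max_generated_tokens=10,
            top_k=50,
        )

Note that old imports of generate and generate_next_token from torchtune.utils will now raise ImportError, since the names are removed from torchtune/utils/__init__.py without a deprecation alias.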