From 152d6b1e1e36d4595990d63e14f8f467b76a5a14 Mon Sep 17 00:00:00 2001 From: Neil Mehta Date: Tue, 26 Nov 2024 13:53:03 -0500 Subject: [PATCH 1/2] Accept mx.array type for prompt argument for stream_generate --- llms/mlx_lm/utils.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 0e2f7af75..7329d626c 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -298,7 +298,7 @@ def _step(y): def stream_generate( model: nn.Module, tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper], - prompt: Union[str, List[int]], + prompt: Union[str, mx.array, List[int]], max_tokens: int = 100, **kwargs, ) -> Generator[GenerationResponse, None, None]: @@ -308,7 +308,7 @@ def stream_generate( Args: model (nn.Module): The model to use for generation. tokenizer (PreTrainedTokenizer): The tokenizer. - prompt (Union[str, List[int]]): The input prompt string or integer tokens. + prompt (Union[str, mx.array, List[int]]): The input prompt string or integer tokens. max_tokens (int): The maximum number of tokens. Default: ``100``. kwargs: The remaining options get passed to :func:`generate_step`. See :func:`generate_step` for more details. @@ -320,7 +320,9 @@ def stream_generate( if not isinstance(tokenizer, TokenizerWrapper): tokenizer = TokenizerWrapper(tokenizer) - prompt = mx.array(prompt if isinstance(prompt, list) else tokenizer.encode(prompt)) + if not isinstance(prompt, mx.array): + prompt = mx.array(prompt if isinstance(prompt, list) else tokenizer.encode(prompt)) + detokenizer = tokenizer.detokenizer with wired_limit(model, [generation_stream]): From 83bc764e78be78e9b493ed8fdecb72420dd7eca8 Mon Sep 17 00:00:00 2001 From: Neil Mehta Date: Tue, 26 Nov 2024 16:46:08 -0500 Subject: [PATCH 2/2] Fix formatting --- llms/mlx_lm/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 7329d626c..f439ca995 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -321,7 +321,9 @@ def stream_generate( tokenizer = TokenizerWrapper(tokenizer) if not isinstance(prompt, mx.array): - prompt = mx.array(prompt if isinstance(prompt, list) else tokenizer.encode(prompt)) + prompt = mx.array( + prompt if isinstance(prompt, list) else tokenizer.encode(prompt) + ) detokenizer = tokenizer.detokenizer