From 0f135396ae7fcb2bad407d6a41296ac84c0fb666 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Sat, 23 Nov 2024 11:47:06 -0800 Subject: [PATCH] Generation refactor: part 2 (#1099) * unify with stream_generate * fixes * nit * some cleanup, warnings, tests * fix test + faster min p + test * version --- llms/README.md | 11 +- llms/mlx_lm/_version.py | 2 +- llms/mlx_lm/chat.py | 10 +- llms/mlx_lm/examples/chat.py | 1 - llms/mlx_lm/examples/generate_response.py | 9 - llms/mlx_lm/generate.py | 46 +---- llms/mlx_lm/sample_utils.py | 22 +-- llms/mlx_lm/server.py | 42 ++--- llms/mlx_lm/tokenizer_utils.py | 11 +- llms/mlx_lm/utils.py | 203 ++++++++++++---------- llms/tests/test_generate.py | 3 +- llms/tests/test_sample_utils.py | 18 +- llms/tests/test_tokenizers.py | 3 +- 13 files changed, 184 insertions(+), 197 deletions(-) diff --git a/llms/README.md b/llms/README.md index eeb3ed6a0..60f68353e 100644 --- a/llms/README.md +++ b/llms/README.md @@ -61,7 +61,7 @@ prompt = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) -response = generate(model, tokenizer, prompt=prompt, verbose=True) +text = generate(model, tokenizer, prompt=prompt, verbose=True) ``` To see a description of all the arguments you can do: @@ -100,8 +100,9 @@ To see a description of all the arguments you can do: #### Streaming -For streaming generation, use the `stream_generate` function. This returns a -generator object which streams the output text, token, and log probabilities. +For streaming generation, use the `stream_generate` function. This yields +a generation response object. + For example, ```python @@ -117,8 +118,8 @@ prompt = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) -for text, *_ in stream_generate(model, tokenizer, prompt, max_tokens=512): - print(t, end="", flush=True) +for response in stream_generate(model, tokenizer, prompt, max_tokens=512): + print(response.text, end="", flush=True) print() ``` diff --git a/llms/mlx_lm/_version.py b/llms/mlx_lm/_version.py index 3811616f2..5168eee4e 100644 --- a/llms/mlx_lm/_version.py +++ b/llms/mlx_lm/_version.py @@ -1,3 +1,3 @@ # Copyright © 2023-2024 Apple Inc. 
-__version__ = "0.19.3" +__version__ = "0.20.0" diff --git a/llms/mlx_lm/chat.py b/llms/mlx_lm/chat.py index c03056a6b..7795d8d79 100644 --- a/llms/mlx_lm/chat.py +++ b/llms/mlx_lm/chat.py @@ -5,7 +5,8 @@ import mlx.core as mx -from .models.cache import load_prompt_cache, make_prompt_cache, save_prompt_cache +from .models.cache import make_prompt_cache +from .sample_utils import make_sampler from .utils import load, stream_generate DEFAULT_TEMP = 0.0 @@ -74,16 +75,15 @@ def main(): prompt = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) - for response, *_ in stream_generate( + for response in stream_generate( model, tokenizer, prompt, args.max_tokens, - temp=args.temp, - top_p=args.top_p, + sampler=make_sampler(args.temp, args.top_p), prompt_cache=prompt_cache, ): - print(response, flush=True, end="") + print(response.text, flush=True, end="") print() diff --git a/llms/mlx_lm/examples/chat.py b/llms/mlx_lm/examples/chat.py index 3bf016884..c7512b3c9 100644 --- a/llms/mlx_lm/examples/chat.py +++ b/llms/mlx_lm/examples/chat.py @@ -42,7 +42,6 @@ tokenizer, prompt=prompt, verbose=True, - temp=0.0, prompt_cache=prompt_cache, ) diff --git a/llms/mlx_lm/examples/generate_response.py b/llms/mlx_lm/examples/generate_response.py index 257306171..e6535b476 100644 --- a/llms/mlx_lm/examples/generate_response.py +++ b/llms/mlx_lm/examples/generate_response.py @@ -23,14 +23,6 @@ # Specify if tokens and timing information will be printed verbose = True -# Some optional arguments for causal language model generation -generation_args = { - "temp": 0.7, - "repetition_penalty": 1.2, - "repetition_context_size": 20, - "top_p": 0.95, -} - # Generate a response with the specified settings response = generate( model=model, @@ -38,5 +30,4 @@ prompt=prompt, max_tokens=max_tokens, verbose=verbose, - **generation_args, ) diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py index de5c5719e..9e96fbdc0 100644 --- a/llms/mlx_lm/generate.py +++ b/llms/mlx_lm/generate.py @@ -7,6 +7,7 @@ import mlx.core as mx from .models.cache import QuantizedKVCache, load_prompt_cache +from .sample_utils import make_sampler from .utils import generate, load DEFAULT_PROMPT = "hello" @@ -97,11 +98,6 @@ def setup_arg_parser(): default=True, help="Log verbose output when 'True' or 'T' or only print the response when 'False' or 'F'", ) - parser.add_argument( - "--colorize", - action="store_true", - help="Colorize output based on T[0] probability", - ) parser.add_argument( "--max-kv-size", type=int, @@ -137,33 +133,6 @@ def setup_arg_parser(): return parser -def colorprint(color, s): - color_codes = { - "black": 30, - "red": 31, - "green": 32, - "yellow": 33, - "blue": 34, - "magenta": 35, - "cyan": 36, - "white": 39, - } - ccode = color_codes.get(color, 30) - print(f"\033[1m\033[{ccode}m{s}\033[0m", end="", flush=True) - - -def colorprint_by_t0(s, t0): - if t0 > 0.95: - color = "white" - elif t0 > 0.70: - color = "green" - elif t0 > 0.30: - color = "yellow" - else: - color = "red" - colorprint(color, s) - - def main(): parser = setup_arg_parser() args = parser.parse_args() @@ -250,21 +219,14 @@ def main(): else: prompt = args.prompt - if args.colorize and not args.verbose: - raise ValueError("Cannot use --colorize with --verbose=False") - formatter = colorprint_by_t0 if args.colorize else None - + sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep) response = generate( model, tokenizer, prompt, - args.max_tokens, + max_tokens=args.max_tokens, verbose=args.verbose, 
- formatter=formatter, - temp=args.temp, - top_p=args.top_p, - min_p=args.min_p, - min_tokens_to_keep=args.min_tokens_to_keep, + sampler=sampler, max_kv_size=args.max_kv_size, prompt_cache=prompt_cache if using_cache else None, kv_bits=args.kv_bits, diff --git a/llms/mlx_lm/sample_utils.py b/llms/mlx_lm/sample_utils.py index c27b52d85..f98684224 100644 --- a/llms/mlx_lm/sample_utils.py +++ b/llms/mlx_lm/sample_utils.py @@ -1,5 +1,6 @@ # Copyright © 2023-2024 Apple Inc. +import math from functools import partial from typing import Callable, Dict, Optional @@ -80,7 +81,7 @@ def logit_bias_processor(_, logits): @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) def min_p_sampling( - logits: mx.array, + logprobs: mx.array, min_p: float, min_tokens_to_keep: int = 1, temperature=1.0, @@ -93,7 +94,7 @@ def min_p_sampling( aggressive given a very high-probability token. Args: - logits: The logits from the model's output. + logprobs: A vector of log probabilities. min_p (float): Minimum token probability. Typical values are in the 0.01-0.2 range, comparably selective as setting `top_p` in the 0.99-0.8 range. @@ -111,28 +112,27 @@ def min_p_sampling( ) # reference implementation: https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L531-L605 - # Softmax probabilities - probs = mx.softmax(logits * (1 / temperature), axis=-1) + logprobs = logprobs * (1 / temperature) # Indices sorted in decreasing order - sorted_indices = mx.argsort(-logits).squeeze(0) - sorted_probs = probs[..., sorted_indices] + sorted_indices = mx.argsort(-logprobs).squeeze(0) + sorted_logprobs = logprobs[..., sorted_indices] # Top probability - top_probs = probs[..., sorted_indices[0]] + top_logprobs = logprobs[..., sorted_indices[0]] # Calculate the min_p threshold - scaled_min_p = min_p * top_probs + scaled_min_p = top_logprobs + math.log(min_p) # Mask tokens that have a probability less than the scaled min_p - tokens_to_remove = sorted_probs < scaled_min_p + tokens_to_remove = sorted_logprobs < scaled_min_p tokens_to_remove[..., :min_tokens_to_keep] = False # Create pool of tokens with probability less than scaled min_p - selected_probs = mx.where(tokens_to_remove, 0, sorted_probs) + selected_logprobs = mx.where(tokens_to_remove, -float("inf"), sorted_logprobs) # Return sampled token - sorted_token = mx.random.categorical(mx.log(selected_probs)) + sorted_token = mx.random.categorical(selected_logprobs) return sorted_indices[sorted_token] diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py index c1365b366..badc6dd37 100644 --- a/llms/mlx_lm/server.py +++ b/llms/mlx_lm/server.py @@ -27,6 +27,7 @@ from ._version import __version__ from .models.cache import make_prompt_cache +from .sample_utils import make_logits_processors, make_sampler from .utils import load, stream_generate @@ -464,25 +465,24 @@ def handle_completion( text = "" tic = time.perf_counter() - for n, (segment, token, logprobs) in enumerate( - stream_generate( - model=self.model, - tokenizer=self.tokenizer, - prompt=prompt, - max_tokens=self.max_tokens, - temp=self.temperature, - repetition_penalty=self.repetition_penalty, - repetition_context_size=self.repetition_context_size, - logit_bias=self.logit_bias, - prompt_cache=self.prompt_cache.cache, - ), + sampler = make_sampler(self.temperature) + logits_processors = make_logits_processors( + self.logit_bias, self.repetition_penalty, self.repetition_context_size + ) + for gen_response in stream_generate( + model=self.model, + 
tokenizer=self.tokenizer, + prompt=prompt, + max_tokens=self.max_tokens, + sampler=sampler, + logits_processors=logits_processors, + prompt_cache=self.prompt_cache.cache, ): - if n == 0: - prompt_time = time.perf_counter() - tic - tic = time.perf_counter() - + segment = gen_response.text text += segment logging.debug(text) + token = gen_response.token + logprobs = gen_response.logprobs tokens.append(token) if self.logprobs > 0: @@ -523,13 +523,9 @@ def handle_completion( self.prompt_cache.tokens.extend(tokens) - gen_time = time.perf_counter() - tic - prompt_tps = len(prompt) / prompt_time - gen_tps = len(tokens) / gen_time - peak_mem = mx.metal.get_peak_memory() / 1e9 - logging.debug(f"Prompt: {prompt_tps:.3f} tokens-per-sec") - logging.debug(f"Generation: {gen_tps:.3f} tokens-per-sec") - logging.debug(f"Peak memory: {peak_mem:.3f} GB") + logging.debug(f"Prompt: {gen_response.prompt_tps:.3f} tokens-per-sec") + logging.debug(f"Generation: {gen_response.generation_tps:.3f} tokens-per-sec") + logging.debug(f"Peak memory: {gen_response.peak_memory:.3f} GB") if self.stream: response = self.generate_response(segment, finish_reason) diff --git a/llms/mlx_lm/tokenizer_utils.py b/llms/mlx_lm/tokenizer_utils.py index 9d390733f..0fa41ac08 100644 --- a/llms/mlx_lm/tokenizer_utils.py +++ b/llms/mlx_lm/tokenizer_utils.py @@ -73,16 +73,16 @@ def __init__(self, tokenizer): def reset(self): self.offset = 0 - self._tokens = [] + self.tokens = [] self._text = "" self._current_tokens = [] self._current_text = "" def add_token(self, token): self._current_tokens.append(token) + self.tokens.append(token) def finalize(self): - self._tokens.extend(self._current_tokens) self._text += self._tokenizer.decode(self._current_tokens) self._current_tokens = [] self._current_text = "" @@ -97,16 +97,11 @@ def text(self): ): self._current_text = self._current_text[:-1] if self._current_text and self._current_text[-1] == "\n": - self._tokens.extend(self._current_tokens) self._text += self._current_text self._current_tokens.clear() self._current_text = "" return self._text + self._current_text - @property - def tokens(self): - return self._tokens - class SPMStreamingDetokenizer(StreamingDetokenizer): """A streaming detokenizer for SPM models. @@ -143,6 +138,7 @@ def _flush(self): self.text += text def add_token(self, token): + self.tokens.append(token) v = self.tokenmap[token] if v.startswith(self._sep): self._flush() @@ -200,6 +196,7 @@ def _maybe_trim_space(self, current_text): return current_text def add_token(self, token): + self.tokens.append(token) v = self.tokenmap[token] is_added = token in self._added_ids if is_added or self._byte_decoder[v[0]] == 32: diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index d4afd4281..496ae4fca 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -8,6 +8,7 @@ import logging import shutil import time +from dataclasses import dataclass from pathlib import Path from textwrap import dedent from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, Union @@ -44,6 +45,32 @@ def __init__(self, message): super().__init__(self.message) +@dataclass +class GenerationResponse: + """ + The output of :func:`stream_generate`. + + Args: + text (str): The next segment of decoded text. This can be an empty string. + token (int): The next token. + logprobs (mx.array): A vector of log probabilities. + prompt_tokens (int): The number of tokens in the prompt. + prompt_tps (float): The prompt processing tokens-per-second. 
+ generation_tokens (int): The number of generated tokens. + generation_tps (float): The tokens-per-second for generation. + peak_memory (float): The peak memory used so far in GB. + """ + + text: str + token: int + logprobs: mx.array + prompt_tokens: int + prompt_tps: float + generation_tokens: int + generation_tps: float + peak_memory: float + + @contextlib.contextmanager def wired_limit(model: nn.Module, streams: Optional[List[mx.Stream]] = None): """ @@ -155,20 +182,21 @@ def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_ def generate_step( prompt: mx.array, model: nn.Module, - temp: float = 0.0, - repetition_penalty: Optional[float] = None, - repetition_context_size: Optional[int] = 20, - top_p: float = 1.0, - min_p: float = 0.0, - min_tokens_to_keep: int = 1, - prefill_step_size: int = 512, + *, + sampler: Optional[Callable[mx.array, mx.array]] = None, + logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None, max_kv_size: Optional[int] = None, prompt_cache: Optional[Any] = None, - logit_bias: Optional[Dict[int, float]] = None, - logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None, + prefill_step_size: int = 512, kv_bits: Optional[int] = None, kv_group_size: int = 64, quantized_kv_start: int = 0, + temp: Optional[float] = None, + repetition_penalty: Optional[float] = None, + repetition_context_size: Optional[int] = None, + top_p: Optional[float] = None, + min_p: Optional[float] = None, + min_tokens_to_keep: Optional[int] = None, ) -> Generator[Tuple[mx.array, mx.array], None, None]: """ A generator producing token ids based on the given prompt from the model. @@ -176,32 +204,21 @@ def generate_step( Args: prompt (mx.array): The input prompt. model (nn.Module): The model to use for generation. - temp (float): The temperature for sampling, if 0 the argmax is used. - Default: ``0``. - repetition_penalty (float, optional): The penalty factor for repeating - tokens. - repetition_context_size (int, optional): The number of tokens to - consider for repetition penalty. Default: ``20``. - top_p (float, optional): Nulceus sampling, higher means model considers - more less likely words. - min_p (float, optional): The minimum value (scaled by the top token's - probability) that a token probability must have to be considered. - min_tokens_to_keep (int, optional): Minimum number of tokens that cannot - be filtered by min_p sampling. prefill_step_size (int): Step size for processing the prompt. max_kv_size (int, optional): Maximum size of the key-value cache. Old entries (except the first 4 tokens) will be overwritten. prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if provided, the cache will be updated in place. - logit_bias (dictionary, optional): Additive logit bias. + sampler (Callable[mx.array, mx.array], optional): A sampler for sampling a + token from a vector of log probabilities. Default: ``None``. logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional): - A list of functions that take tokens and logits and return the processed - logits. Default: ``None``. + A list of functions that take tokens and logits and return the processed + logits. Default: ``None``. kv_bits (int, optional): Number of bits to use for KV cache quantization. - None implies no cache quantization. Default: ``None``. + None implies no cache quantization. Default: ``None``. kv_group_size (int): Group size for KV cache quantization. Default: ``64``. 
quantized_kv_start (int): Step to begin using a quantized KV cache. - when ``kv_bits`` is non-None. Default: ``0``. + when ``kv_bits`` is non-None. Default: ``0``. Yields: Tuple[mx.array, mx.array]: One token and a vector of log probabilities. @@ -219,10 +236,22 @@ def generate_step( elif len(prompt_cache) != len(model.layers): raise ValueError("Wrong number of layers in the prompt cache.") - sampler = make_sampler(temp, top_p, min_p, min_tokens_to_keep) - logits_processors = logits_processors or [] - logits_processors.extend( - make_logits_processors(logit_bias, repetition_penalty, repetition_context_size) + if temp is not None or top_p is not None or min_tokens_to_keep is not None: + print( + "[Warning] Specifying sampling arguments to ``generate_step`` is " + "deprecated. Pass in a ``sampler`` instead." + ) + if repetition_penalty is not None: + print( + "[Warning] Specifying ``repetition_penalty`` is deprecated. " + "Pass in ``logits_processors`` instead." + ) + + sampler = sampler or make_sampler( + temp or 0.0, top_p or 0.0, min_p or 0.0, min_tokens_to_keep or 1 + ) + logits_processors = logits_processors or make_logits_processors( + None, repetition_penalty, repetition_context_size or 20 ) def _step(y): @@ -290,17 +319,20 @@ def stream_generate( if not isinstance(tokenizer, TokenizerWrapper): tokenizer = TokenizerWrapper(tokenizer) - prompt_tokens = mx.array( - prompt if isinstance(prompt, list) else tokenizer.encode(prompt) - ) + prompt = mx.array(prompt if isinstance(prompt, list) else tokenizer.encode(prompt)) detokenizer = tokenizer.detokenizer with wired_limit(model, [generation_stream]): detokenizer.reset() - for n, (token, logits) in zip( + tic = time.perf_counter() + for n, (token, logprobs) in zip( range(max_tokens), - generate_step(prompt_tokens, model, **kwargs), + generate_step(prompt, model, **kwargs), ): + if n == 0: + prompt_time = time.perf_counter() - tic + prompt_tps = prompt.size / prompt_time + tic = time.perf_counter() if token == tokenizer.eos_token_id: break @@ -309,17 +341,34 @@ def stream_generate( if n == (max_tokens - 1): break - yield detokenizer.last_segment, token, logits + yield GenerationResponse( + text=detokenizer.last_segment, + token=token, + logprobs=logprobs, + prompt_tokens=prompt.size, + prompt_tps=prompt_tps, + generation_tokens=n + 1, + generation_tps=(n + 1) / (time.perf_counter() - tic), + peak_memory=mx.metal.get_peak_memory() / 1e9, + ) detokenizer.finalize() - yield detokenizer.last_segment, token, logits + yield GenerationResponse( + text=detokenizer.last_segment, + token=token, + logprobs=logprobs, + prompt_tokens=prompt.size, + prompt_tps=prompt_tps, + generation_tokens=n + 1, + generation_tps=(n + 1) / (time.perf_counter() - tic), + peak_memory=mx.metal.get_peak_memory() / 1e9, + ) def generate( model: nn.Module, tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper], prompt: str, - max_tokens: int = 100, verbose: bool = False, formatter: Optional[Callable] = None, **kwargs, @@ -334,64 +383,40 @@ def generate( max_tokens (int): The maximum number of tokens. Default: ``100``. verbose (bool): If ``True``, print tokens and timing information. Default: ``False``. - formatter (Optional[Callable]): A function which takes a token and a - probability and displays it. - kwargs: The remaining options get passed to :func:`generate_step`. - See :func:`generate_step` for more details. + kwargs: The remaining options get passed to :func:`stream_generate`. + See :func:`stream_generate` for more details. 
""" - if not isinstance(tokenizer, TokenizerWrapper): - tokenizer = TokenizerWrapper(tokenizer) - + if formatter is not None: + print( + "[Warning] Text formatting is deprecated and no longer used. " + "The argument will be removed in a future version." + ) if verbose: print("=" * 10) print("Prompt:", prompt) - prompt_tokens = mx.array(tokenizer.encode(prompt)) - detokenizer = tokenizer.detokenizer - - with wired_limit(model, [generation_stream]): - tic = time.perf_counter() - detokenizer.reset() - for n, (token, logprobs) in zip( - range(max_tokens), - generate_step(prompt_tokens, model, **kwargs), - ): - if n == 0: - prompt_time = time.perf_counter() - tic - tic = time.perf_counter() - if token == tokenizer.eos_token_id: - break - detokenizer.add_token(token) - - if verbose: - if formatter: - # We have to finalize so that the prob corresponds to the last segment - detokenizer.finalize() - prob = mx.exp(logprobs[token]).item() - formatter(detokenizer.last_segment, prob) - else: - print(detokenizer.last_segment, end="", flush=True) - - token_count = n + 1 - detokenizer.finalize() - + text = "" + for response in stream_generate(model, tokenizer, prompt, **kwargs): if verbose: - gen_time = time.perf_counter() - tic - print(detokenizer.last_segment, flush=True) - print("=" * 10) - if token_count == 0: - print("No tokens generated for this prompt") - return - prompt_tps = prompt_tokens.size / prompt_time - gen_tps = (token_count - 1) / gen_time - print( - f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec" - ) - print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec") - peak_mem = mx.metal.get_peak_memory() / 1e9 - print(f"Peak memory: {peak_mem:.3f} GB") + print(response.text, end="", flush=True) + text += response.text - return detokenizer.text + if verbose: + print() + print("=" * 10) + if len(text) == 0: + print("No text generated for this prompt") + return + print( + f"Prompt: {response.prompt_tokens} tokens, " + f"{response.prompt_tps:.3f} tokens-per-sec" + ) + print( + f"Generation: {response.generation_tokens} tokens, " + f"{response.generation_tps:.3f} tokens-per-sec" + ) + print(f"Peak memory: {response.peak_memory:.3f} GB") + return text def load_config(model_path: Path) -> dict: diff --git a/llms/tests/test_generate.py b/llms/tests/test_generate.py index e0a372a98..f23453943 100644 --- a/llms/tests/test_generate.py +++ b/llms/tests/test_generate.py @@ -2,6 +2,7 @@ import unittest +from mlx_lm.sample_utils import make_logits_processors from mlx_lm.utils import generate, load @@ -25,8 +26,8 @@ def test_generate_with_logit_bias(self): self.tokenizer, "hello", max_tokens=5, + logits_processors=make_logits_processors(logit_bias), verbose=False, - logit_bias=logit_bias, ) self.assertEqual(text, "!!!!!") diff --git a/llms/tests/test_sample_utils.py b/llms/tests/test_sample_utils.py index ec0e2cb74..ebc90ce83 100644 --- a/llms/tests/test_sample_utils.py +++ b/llms/tests/test_sample_utils.py @@ -1,10 +1,10 @@ import unittest import mlx.core as mx -from mlx_lm.sample_utils import top_p_sampling +from mlx_lm.sample_utils import min_p_sampling, top_p_sampling -class TestSamplingUtils(unittest.TestCase): +class TestSampleUtils(unittest.TestCase): def test_top_p_sampling(self): probs = mx.array([0.9, 0.0, 0.0, 0.1])[None] logits = mx.log(probs) @@ -28,6 +28,20 @@ def test_top_p_sampling(self): token = top_p_sampling(logits, 0.95, temperature).item() self.assertTrue(token in (1, 2, 3)) + def test_min_p_sampling(self): + probs = mx.array([0.9, 0.0, 0.0, 
0.1])[None] + logits = mx.log(probs) + temperature = 1.0 + token = min_p_sampling(logits, 0.8) + self.assertEqual(token, 0) + + probs = mx.array([0.9, 0.0, 0.0, 0.1])[None] + logits = mx.log(probs) + temperature = 1.0 + for _ in range(5): + token = min_p_sampling(logits, 0.05) + self.assertTrue(token in (0, 3)) + if __name__ == "__main__": unittest.main() diff --git a/llms/tests/test_tokenizers.py b/llms/tests/test_tokenizers.py index 9c30d51e4..db6b9f9e4 100644 --- a/llms/tests/test_tokenizers.py +++ b/llms/tests/test_tokenizers.py @@ -34,10 +34,11 @@ def check(tokens): detokenizer = tokenizer.detokenizer detokenizer.reset() text = "" - for t in tokens: + for e, t in enumerate(tokens): detokenizer.add_token(t) seg = detokenizer.last_segment text += seg + self.assertEqual(detokenizer.tokens, tokens[: e + 1]) detokenizer.finalize() text += detokenizer.last_segment self.assertEqual(text, expected_text)
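
Notes (not part of the patch): a brief usage sketch of the refactored API introduced above. `stream_generate` now yields `GenerationResponse` objects, and sampling options are bundled into a `sampler` callable built with `make_sampler` instead of being passed as keyword arguments. The model repo and prompt below are placeholders, not taken from this change.

```python
from mlx_lm import load, stream_generate
from mlx_lm.sample_utils import make_sampler

# Placeholder checkpoint; any mlx-community model works the same way.
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Write a haiku about refactoring."}],
    tokenize=False,
    add_generation_prompt=True,
)

# Sampling settings now travel through a sampler callable.
sampler = make_sampler(temp=0.7, top_p=0.95)

for response in stream_generate(model, tokenizer, prompt, max_tokens=256, sampler=sampler):
    print(response.text, end="", flush=True)

# The last yielded response also carries the timing and memory statistics
# that `generate --verbose` and the server now read from it.
print(f"\n{response.generation_tps:.1f} tokens/sec, {response.peak_memory:.2f} GB peak")
```

The other subtle change above is that `min_p_sampling` now filters in log space: the threshold `min_p * top_prob` becomes `log(top_prob) + log(min_p)`, which avoids computing a softmax over the full vocabulary. A small sanity-check sketch of the equivalence, assuming `mlx` is installed:

```python
import math
import mlx.core as mx

probs = mx.array([0.6, 0.25, 0.1, 0.05])
logprobs = mx.log(probs)
min_p = 0.2

old_threshold = min_p * probs.max()               # probability-space threshold
new_threshold = logprobs.max() + math.log(min_p)  # log-space threshold used in the patch

# The same set of tokens survives either filter.
assert mx.array_equal(probs >= old_threshold, logprobs >= new_threshold)
```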