Skip to content

Commit

Permalink
Allow prompt callback to generate_step (ml-explore#1133)
Browse files Browse the repository at this point in the history
* allow prompt callback and use in cache_prompt

* nit

* comments

* bump version
  • Loading branch information
awni authored and mokeddembillel committed Dec 16, 2024
1 parent a73de93 commit e08c470
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 48 deletions.
2 changes: 1 addition & 1 deletion llms/mlx_lm/_version.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright © 2023-2024 Apple Inc.

__version__ = "0.20.1"
__version__ = "0.20.2"
35 changes: 14 additions & 21 deletions llms/mlx_lm/cache_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import mlx.core as mx

from .models.cache import make_prompt_cache, save_prompt_cache
from .utils import load, maybe_quantize_kv_cache
from .utils import generate_step, load

DEFAULT_QUANTIZED_KV_START = 5000

Expand Down Expand Up @@ -50,12 +50,6 @@ def setup_arg_parser():
action="store_true",
help="Use the default chat template",
)
parser.add_argument(
"--cache-limit-gb",
type=int,
default=None,
help="Set the MLX cache limit in GB",
)
parser.add_argument(
"--max-kv-size",
type=int,
Expand Down Expand Up @@ -99,9 +93,6 @@ def main():
parser = setup_arg_parser()
args = parser.parse_args()

if args.cache_limit_gb is not None:
mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)

# Building tokenizer_config
tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
if args.eos_token is not None:
Expand Down Expand Up @@ -144,26 +135,28 @@ def main():
y = mx.array(tokenizer.encode(prompt))

# Process the prompt
processed = 0
step_size = 512
start = time.time()
max_msg_len = 0
while y.size > 0:

model(y[:step_size][None], cache=cache)
mx.eval([c.state for c in cache])
mx.metal.clear_cache()
processed += min(y.size, step_size)
y = y[step_size:]
def callback(processed, total_tokens):
current = time.time()
speed = processed / (current - start)
msg = f"\rProcessed {processed:6d} tokens ({speed:6.2f} tok/s)"
nonlocal max_msg_len
max_msg_len = max(max_msg_len, len(msg))
print(msg + " " * (max_msg_len - len(msg)), end="", flush=True)

maybe_quantize_kv_cache(
cache, args.quantized_kv_start, args.kv_group_size, args.kv_bits
)
for _ in generate_step(
y,
model,
max_tokens=0,
prompt_cache=cache,
kv_bits=args.kv_bits,
kv_group_size=args.kv_group_size,
quantized_kv_start=args.quantized_kv_start,
prompt_progress_callback=callback,
):
pass

print()
print(f"Peak memory: {mx.metal.get_peak_memory() / 1e9:.3f} GB")
Expand Down
2 changes: 1 addition & 1 deletion llms/mlx_lm/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def setup_arg_parser():
)
parser.add_argument(
"--min-tokens-to-keep",
type=float,
type=int,
default=DEFAULT_MIN_TOKENS_TO_KEEP,
help="Minimum tokens to keep for min-p sampling.",
)
Expand Down
44 changes: 26 additions & 18 deletions llms/mlx_lm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ def generate_step(
prompt: mx.array,
model: nn.Module,
*,
max_tokens: int = 256,
sampler: Optional[Callable[mx.array, mx.array]] = None,
logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
max_kv_size: Optional[int] = None,
Expand All @@ -191,6 +192,7 @@ def generate_step(
kv_bits: Optional[int] = None,
kv_group_size: int = 64,
quantized_kv_start: int = 0,
prompt_progress_callback: Optional[Callable[int, int]] = None,
temp: Optional[float] = None,
repetition_penalty: Optional[float] = None,
repetition_context_size: Optional[int] = None,
Expand All @@ -204,21 +206,25 @@ def generate_step(
Args:
prompt (mx.array): The input prompt.
model (nn.Module): The model to use for generation.
prefill_step_size (int): Step size for processing the prompt.
max_kv_size (int, optional): Maximum size of the key-value cache. Old
entries (except the first 4 tokens) will be overwritten.
prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
provided, the cache will be updated in place.
max_tokens (int): The maximum number of tokens. Use``-1`` for an infinite
generator. Default: ``256``.
sampler (Callable[mx.array, mx.array], optional): A sampler for sampling a
token from a vector of log probabilities. Default: ``None``.
logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
A list of functions that take tokens and logits and return the processed
logits. Default: ``None``.
max_kv_size (int, optional): Maximum size of the key-value cache. Old
entries (except the first 4 tokens) will be overwritten.
prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
provided, the cache will be updated in place.
prefill_step_size (int): Step size for processing the prompt.
kv_bits (int, optional): Number of bits to use for KV cache quantization.
None implies no cache quantization. Default: ``None``.
kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
quantized_kv_start (int): Step to begin using a quantized KV cache.
when ``kv_bits`` is non-None. Default: ``0``.
prompt_prorgress_callback (Callable[int, int]): A call-back which takes the
prompt tokens processed so far and the total number of prompt tokens.
Yields:
Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
Expand Down Expand Up @@ -253,6 +259,7 @@ def generate_step(
logits_processors = logits_processors or make_logits_processors(
None, repetition_penalty, repetition_context_size or 20
)
prompt_progress_callback = prompt_progress_callback or (lambda *_: None)

def _step(y):
with mx.stream(generation_stream):
Expand All @@ -275,9 +282,13 @@ def _step(y):
return y, logprobs.squeeze(0)

with mx.stream(generation_stream):
total_prompt_tokens = y.size
prompt_processed_tokens = 0
while y.size > prefill_step_size:
model(y[:prefill_step_size][None], cache=prompt_cache)
mx.eval([c.state for c in prompt_cache])
prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens)
prompt_processed_tokens += prefill_step_size
y = y[prefill_step_size:]
mx.metal.clear_cache()

Expand All @@ -286,20 +297,25 @@ def _step(y):
mx.async_eval(y, logprobs)
n = 0
while True:
next_y, next_logprobs = _step(y)
mx.async_eval(next_y, next_logprobs)
if n != max_tokens:
next_y, next_logprobs = _step(y)
mx.async_eval(next_y, next_logprobs)
if n == 0:
mx.eval(y)
prompt_progress_callback(total_prompt_tokens, total_prompt_tokens)
if n == max_tokens:
break
yield y.item(), logprobs
if n % 256 == 0:
mx.metal.clear_cache()
n += 1
y, logprobs = next_y, next_logprobs
n += 1


def stream_generate(
model: nn.Module,
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
prompt: Union[str, mx.array, List[int]],
max_tokens: int = 100,
**kwargs,
) -> Generator[GenerationResponse, None, None]:
"""
Expand All @@ -309,7 +325,6 @@ def stream_generate(
model (nn.Module): The model to use for generation.
tokenizer (PreTrainedTokenizer): The tokenizer.
prompt (Union[str, mx.array, List[int]]): The input prompt string or integer tokens.
max_tokens (int): The maximum number of tokens. Default: ``100``.
kwargs: The remaining options get passed to :func:`generate_step`.
See :func:`generate_step` for more details.
Expand All @@ -330,10 +345,7 @@ def stream_generate(
with wired_limit(model, [generation_stream]):
detokenizer.reset()
tic = time.perf_counter()
for n, (token, logprobs) in zip(
range(max_tokens),
generate_step(prompt, model, **kwargs),
):
for n, (token, logprobs) in enumerate(generate_step(prompt, model, **kwargs)):
if n == 0:
prompt_time = time.perf_counter() - tic
prompt_tps = prompt.size / prompt_time
Expand All @@ -343,9 +355,6 @@ def stream_generate(

detokenizer.add_token(token)

if n == (max_tokens - 1):
break

yield GenerationResponse(
text=detokenizer.last_segment,
token=token,
Expand Down Expand Up @@ -385,7 +394,6 @@ def generate(
model (nn.Module): The language model.
tokenizer (PreTrainedTokenizer): The tokenizer.
prompt (str): The string prompt.
max_tokens (int): The maximum number of tokens. Default: ``100``.
verbose (bool): If ``True``, print tokens and timing information.
Default: ``False``.
kwargs: The remaining options get passed to :func:`stream_generate`.
Expand Down
13 changes: 6 additions & 7 deletions llms/tests/test_prompt_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,21 +121,20 @@ def test_save_load_mixed_cache(self):
def test_cache_with_generate(self):
model, tokenizer = load(HF_MODEL_PATH)
prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0]
results = zip(range(4), generate_step(prompt, model))
toks, all_logits = zip(*(r[1] for r in results))
results = list(generate_step(prompt, model, max_tokens=4))
toks, all_logits = zip(*results)

prompt_cache = make_prompt_cache(model)
i = 0
for _, (tok, logits) in zip(
range(2), generate_step(prompt, model, prompt_cache=prompt_cache)
for tok, logits in generate_step(
prompt, model, prompt_cache=prompt_cache, max_tokens=2
):
self.assertEqual(tok, toks[i])
self.assertTrue(mx.allclose(logits, all_logits[i]))
i += 1

for _, (tok, logits) in zip(
range(1),
generate_step(mx.array([toks[i]]), model, prompt_cache=prompt_cache),
for tok, logits in generate_step(
mx.array([toks[i]]), model, prompt_cache=prompt_cache, max_tokens=1
):
i += 1
self.assertEqual(tok, toks[i])
Expand Down

0 comments on commit e08c470

Please sign in to comment.