Allow prompt callback to generate_step #1133

Merged · 4 commits · Dec 4, 2024
2 changes: 1 addition & 1 deletion llms/mlx_lm/_version.py
@@ -1,3 +1,3 @@
# Copyright © 2023-2024 Apple Inc.

__version__ = "0.20.1"
__version__ = "0.20.2"
35 changes: 14 additions & 21 deletions llms/mlx_lm/cache_prompt.py
@@ -8,7 +8,7 @@
import mlx.core as mx

from .models.cache import make_prompt_cache, save_prompt_cache
from .utils import load, maybe_quantize_kv_cache
from .utils import generate_step, load

DEFAULT_QUANTIZED_KV_START = 5000

@@ -50,12 +50,6 @@ def setup_arg_parser():
action="store_true",
help="Use the default chat template",
)
parser.add_argument(
"--cache-limit-gb",
type=int,
default=None,
help="Set the MLX cache limit in GB",
)
parser.add_argument(
"--max-kv-size",
type=int,
@@ -99,9 +93,6 @@ def main():
parser = setup_arg_parser()
args = parser.parse_args()

if args.cache_limit_gb is not None:
mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)

# Building tokenizer_config
tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
if args.eos_token is not None:
@@ -144,26 +135,28 @@ def main():
y = mx.array(tokenizer.encode(prompt))

# Process the prompt
processed = 0
step_size = 512
start = time.time()
max_msg_len = 0
while y.size > 0:

model(y[:step_size][None], cache=cache)
mx.eval([c.state for c in cache])
mx.metal.clear_cache()
processed += min(y.size, step_size)
y = y[step_size:]
def callback(processed, total_tokens):
current = time.time()
speed = processed / (current - start)
msg = f"\rProcessed {processed:6d} tokens ({speed:6.2f} tok/s)"
nonlocal max_msg_len
max_msg_len = max(max_msg_len, len(msg))
print(msg + " " * (max_msg_len - len(msg)), end="", flush=True)

maybe_quantize_kv_cache(
cache, args.quantized_kv_start, args.kv_group_size, args.kv_bits
)
for _ in generate_step(
Review comment (Member):

This does an unnecessary sampling but I like very much that it removes so much duplicated logic.

@awni (Member, Author) replied on Dec 3, 2024:

Right, it does the async eval on the first token / log probs. It should be negligible compared to the rest of the computation.

It's nice that it simplifies the logic a bit. Mostly I changed this to have an example use of the callback and to verify that it was working correctly.

Another option I considered is to split out a prefill_prompt API which generate_step can use and which cache_prompt.py can use. Maybe it's better, but so far we don't need to use it anywhere except here. If we end up needing to do the prompt computation in more places I will play around with that.

y,
model,
max_tokens=0,
prompt_cache=cache,
kv_bits=args.kv_bits,
kv_group_size=args.kv_group_size,
quantized_kv_start=args.quantized_kv_start,
prompt_progress_callback=callback,
):
pass

print()
print(f"Peak memory: {mx.metal.get_peak_memory() / 1e9:.3f} GB")
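
The review thread above discusses reusing generate_step for prefill instead of a hand-rolled loop. As a standalone illustration (not part of the PR), here is a minimal sketch of that pattern, mirroring the imports and arguments visible in the cache_prompt.py diff; the model path and prompt text are placeholders:

# Hedged sketch: fill a prompt cache by running generate_step with max_tokens=0,
# reporting prefill progress through the new prompt_progress_callback.
import time

import mlx.core as mx

from mlx_lm.models.cache import make_prompt_cache, save_prompt_cache
from mlx_lm.utils import generate_step, load

model, tokenizer = load("path/to/model")  # placeholder model path
prompt = mx.array(tokenizer.encode("A very long prompt worth caching ..."))
cache = make_prompt_cache(model)

start = time.time()

def progress(processed, total):
    # Receives (prompt tokens processed so far, total prompt tokens).
    speed = processed / max(time.time() - start, 1e-6)
    print(f"\rProcessed {processed:6d}/{total} tokens ({speed:6.2f} tok/s)", end="", flush=True)

# With max_tokens=0 the generator yields nothing, so the loop body never runs:
# only the prompt is processed and the cache is updated in place.
for _ in generate_step(
    prompt,
    model,
    max_tokens=0,
    prompt_cache=cache,
    prompt_progress_callback=progress,
):
    pass

print()
save_prompt_cache("prompt.safetensors", cache)

As noted in the thread, this still performs one sampling step for the first token, which should be negligible relative to the prompt computation.
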
2 changes: 1 addition & 1 deletion llms/mlx_lm/generate.py
@@ -77,7 +77,7 @@ def setup_arg_parser():
)
parser.add_argument(
"--min-tokens-to-keep",
type=float,
type=int,
default=DEFAULT_MIN_TOKENS_TO_KEEP,
help="Minimum tokens to keep for min-p sampling.",
)
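
Because --min-tokens-to-keep feeds min-p sampling, the value is a count of candidate tokens that always survive the filter, hence the int type. A hedged Python equivalent is sketched below; it assumes make_sampler in mlx_lm.sample_utils accepts min_p and min_tokens_to_keep keywords (as the CLI flags suggest), and the model path is a placeholder:

import mlx.core as mx

from mlx_lm.sample_utils import make_sampler
from mlx_lm.utils import generate_step, load

model, tokenizer = load("path/to/model")  # placeholder model path
prompt = mx.array(tokenizer.encode("Hello"))

# min_tokens_to_keep is an integer: the number of top candidates kept even if
# their probabilities fall below the min-p threshold.
sampler = make_sampler(temp=0.8, min_p=0.1, min_tokens_to_keep=3)

for token, _ in generate_step(prompt, model, max_tokens=32, sampler=sampler):
    print(tokenizer.decode([token]), end="", flush=True)
print()
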
44 changes: 26 additions & 18 deletions llms/mlx_lm/utils.py
@@ -183,6 +183,7 @@ def generate_step(
prompt: mx.array,
model: nn.Module,
*,
max_tokens: int = 256,
sampler: Optional[Callable[mx.array, mx.array]] = None,
logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
max_kv_size: Optional[int] = None,
@@ -191,6 +192,7 @@
kv_bits: Optional[int] = None,
kv_group_size: int = 64,
quantized_kv_start: int = 0,
prompt_progress_callback: Optional[Callable[int, int]] = None,
temp: Optional[float] = None,
repetition_penalty: Optional[float] = None,
repetition_context_size: Optional[int] = None,
@@ -204,21 +206,25 @@
Args:
prompt (mx.array): The input prompt.
model (nn.Module): The model to use for generation.
prefill_step_size (int): Step size for processing the prompt.
max_kv_size (int, optional): Maximum size of the key-value cache. Old
entries (except the first 4 tokens) will be overwritten.
prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
provided, the cache will be updated in place.
max_tokens (int): The maximum number of tokens. Use ``-1`` for an infinite
generator. Default: ``256``.
sampler (Callable[mx.array, mx.array], optional): A sampler for sampling a
token from a vector of log probabilities. Default: ``None``.
logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
A list of functions that take tokens and logits and return the processed
logits. Default: ``None``.
max_kv_size (int, optional): Maximum size of the key-value cache. Old
entries (except the first 4 tokens) will be overwritten.
prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
provided, the cache will be updated in place.
prefill_step_size (int): Step size for processing the prompt.
kv_bits (int, optional): Number of bits to use for KV cache quantization.
None implies no cache quantization. Default: ``None``.
kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
quantized_kv_start (int): Step to begin using a quantized KV cache
when ``kv_bits`` is non-None. Default: ``0``.
prompt_progress_callback (Callable[int, int]): A callback which takes the
prompt tokens processed so far and the total number of prompt tokens.

Yields:
Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
@@ -253,6 +259,7 @@ def generate_step(
logits_processors = logits_processors or make_logits_processors(
None, repetition_penalty, repetition_context_size or 20
)
prompt_progress_callback = prompt_progress_callback or (lambda *_: None)

def _step(y):
with mx.stream(generation_stream):
@@ -275,9 +282,13 @@ def _step(y):
return y, logprobs.squeeze(0)

with mx.stream(generation_stream):
total_prompt_tokens = y.size
prompt_processed_tokens = 0
while y.size > prefill_step_size:
model(y[:prefill_step_size][None], cache=prompt_cache)
mx.eval([c.state for c in prompt_cache])
prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens)
prompt_processed_tokens += prefill_step_size
y = y[prefill_step_size:]
mx.metal.clear_cache()

@@ -286,20 +297,25 @@ def _step(y):
mx.async_eval(y, logprobs)
n = 0
while True:
next_y, next_logprobs = _step(y)
mx.async_eval(next_y, next_logprobs)
if n != max_tokens:
next_y, next_logprobs = _step(y)
mx.async_eval(next_y, next_logprobs)
if n == 0:
mx.eval(y)
prompt_progress_callback(total_prompt_tokens, total_prompt_tokens)
if n == max_tokens:
break
yield y.item(), logprobs
if n % 256 == 0:
mx.metal.clear_cache()
n += 1
y, logprobs = next_y, next_logprobs
n += 1


def stream_generate(
model: nn.Module,
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
prompt: Union[str, mx.array, List[int]],
max_tokens: int = 100,
**kwargs,
) -> Generator[GenerationResponse, None, None]:
"""
@@ -309,7 +325,6 @@ def stream_generate(
model (nn.Module): The model to use for generation.
tokenizer (PreTrainedTokenizer): The tokenizer.
prompt (Union[str, mx.array, List[int]]): The input prompt string or integer tokens.
max_tokens (int): The maximum number of tokens. Default: ``100``.
kwargs: The remaining options get passed to :func:`generate_step`.
See :func:`generate_step` for more details.

@@ -330,10 +345,7 @@ def stream_generate(
with wired_limit(model, [generation_stream]):
detokenizer.reset()
tic = time.perf_counter()
for n, (token, logprobs) in zip(
range(max_tokens),
generate_step(prompt, model, **kwargs),
):
for n, (token, logprobs) in enumerate(generate_step(prompt, model, **kwargs)):
if n == 0:
prompt_time = time.perf_counter() - tic
prompt_tps = prompt.size / prompt_time
@@ -343,9 +355,6 @@ def stream_generate(

detokenizer.add_token(token)

if n == (max_tokens - 1):
break

yield GenerationResponse(
text=detokenizer.last_segment,
token=token,
@@ -385,7 +394,6 @@ def generate(
model (nn.Module): The language model.
tokenizer (PreTrainedTokenizer): The tokenizer.
prompt (str): The string prompt.
max_tokens (int): The maximum number of tokens. Default: ``100``.
verbose (bool): If ``True``, print tokens and timing information.
Default: ``False``.
kwargs: The remaining options get passed to :func:`stream_generate`.
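
To summarize the signature changes in this file: max_tokens now belongs to generate_step (default 256), and stream_generate forwards it, along with prompt_progress_callback, through **kwargs. A hedged sketch of the updated call pattern (model path and prompt are placeholders):

from mlx_lm.utils import load, stream_generate

model, tokenizer = load("path/to/model")  # placeholder model path

def on_prefill(processed, total):
    # Invoked as prompt chunks are processed and once when prefill completes.
    print(f"prefill: {processed}/{total} prompt tokens")

for response in stream_generate(
    model,
    tokenizer,
    "Write a haiku about key-value caches.",
    max_tokens=64,                        # forwarded to generate_step; default is now 256
    prompt_progress_callback=on_prefill,  # also forwarded to generate_step
):
    print(response.text, end="", flush=True)
print()
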
13 changes: 6 additions & 7 deletions llms/tests/test_prompt_cache.py
@@ -121,21 +121,20 @@ def test_save_load_mixed_cache(self):
def test_cache_with_generate(self):
model, tokenizer = load(HF_MODEL_PATH)
prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0]
results = zip(range(4), generate_step(prompt, model))
toks, all_logits = zip(*(r[1] for r in results))
results = list(generate_step(prompt, model, max_tokens=4))
toks, all_logits = zip(*results)

prompt_cache = make_prompt_cache(model)
i = 0
for _, (tok, logits) in zip(
range(2), generate_step(prompt, model, prompt_cache=prompt_cache)
for tok, logits in generate_step(
prompt, model, prompt_cache=prompt_cache, max_tokens=2
):
self.assertEqual(tok, toks[i])
self.assertTrue(mx.allclose(logits, all_logits[i]))
i += 1

for _, (tok, logits) in zip(
range(1),
generate_step(mx.array([toks[i]]), model, prompt_cache=prompt_cache),
for tok, logits in generate_step(
mx.array([toks[i]]), model, prompt_cache=prompt_cache, max_tokens=1
):
i += 1
self.assertEqual(tok, toks[i])