Generation refactor: part 2 (#1099)

* unify with stream_generate

* fixes

* nit

* some cleanup, warnings, tests

* fix test + faster min p + test

* version
awni authored Nov 23, 2024
1 parent 004eb4c commit 0f13539
Showing 13 changed files with 184 additions and 197 deletions.
11 changes: 6 additions & 5 deletions llms/README.md
@@ -61,7 +61,7 @@ prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)

response = generate(model, tokenizer, prompt=prompt, verbose=True)
text = generate(model, tokenizer, prompt=prompt, verbose=True)
```

To see a description of all the arguments you can do:
@@ -100,8 +100,9 @@ To see a description of all the arguments you can do:

#### Streaming

For streaming generation, use the `stream_generate` function. This returns a
generator object which streams the output text, token, and log probabilities.
For streaming generation, use the `stream_generate` function. This yields
a generation response object.

For example,

```python
@@ -117,8 +118,8 @@ prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)

for text, *_ in stream_generate(model, tokenizer, prompt, max_tokens=512):
print(t, end="", flush=True)
for response in stream_generate(model, tokenizer, prompt, max_tokens=512):
print(response.text, end="", flush=True)
print()
```
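
The response object carries more than the text segment. The fields `server.py` reads later in this commit (`token`, `logprobs`, and, on the last response it sees, `prompt_tps`, `generation_tps`, `peak_memory`) are available the same way. A minimal sketch of reading them; the model path and prompt are illustrative:

```python
from mlx_lm import load, stream_generate

# Illustrative model path; any MLX-converted chat model works here.
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Write a haiku about autumn."}],
    tokenize=False,
    add_generation_prompt=True,
)

for response in stream_generate(model, tokenizer, prompt, max_tokens=128):
    print(response.text, end="", flush=True)

# Besides the decoded segment, each response exposes the raw token id and the
# log-probability vector, and the last one carries the run statistics.
print()
print("last token id:", response.token)
print(f"prompt: {response.prompt_tps:.1f} tok/s, generation: {response.generation_tps:.1f} tok/s")
print(f"peak memory: {response.peak_memory:.2f} GB")
```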

2 changes: 1 addition & 1 deletion llms/mlx_lm/_version.py
@@ -1,3 +1,3 @@
# Copyright © 2023-2024 Apple Inc.

__version__ = "0.19.3"
__version__ = "0.20.0"
10 changes: 5 additions & 5 deletions llms/mlx_lm/chat.py
@@ -5,7 +5,8 @@

import mlx.core as mx

from .models.cache import load_prompt_cache, make_prompt_cache, save_prompt_cache
from .models.cache import make_prompt_cache
from .sample_utils import make_sampler
from .utils import load, stream_generate

DEFAULT_TEMP = 0.0
@@ -74,16 +75,15 @@ def main():
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
for response, *_ in stream_generate(
for response in stream_generate(
model,
tokenizer,
prompt,
args.max_tokens,
temp=args.temp,
top_p=args.top_p,
sampler=make_sampler(args.temp, args.top_p),
prompt_cache=prompt_cache,
):
print(response, flush=True, end="")
print(response.text, flush=True, end="")
print()
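
`stream_generate` no longer accepts `temp`/`top_p` keywords; sampling is injected as a callable built by `make_sampler` (positional order `temp, top_p, min_p, min_tokens_to_keep`, as `generate.py` uses below). A hand-rolled sampler is just a function over the log-probabilities. A sketch, assuming the callable receives a `(batch, vocab)` array of logprobs and returns token ids, which matches how `min_p_sampling` is written in this commit; the model path and prompt are illustrative:

```python
import mlx.core as mx

from mlx_lm import load, stream_generate
from mlx_lm.sample_utils import make_sampler

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")  # illustrative path

# Built-in sampler, replacing the old temp/top_p keyword arguments.
sampler = make_sampler(0.7, 0.9)

# A custom sampler is just logprobs -> token ids; swap it in via sampler=greedy.
def greedy(logprobs: mx.array) -> mx.array:
    return mx.argmax(logprobs, axis=-1)

prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Name three uses of a prompt cache."}],
    tokenize=False,
    add_generation_prompt=True,
)

for response in stream_generate(model, tokenizer, prompt, max_tokens=256, sampler=sampler):
    print(response.text, end="", flush=True)
print()
```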


1 change: 0 additions & 1 deletion llms/mlx_lm/examples/chat.py
@@ -42,7 +42,6 @@
tokenizer,
prompt=prompt,
verbose=True,
temp=0.0,
prompt_cache=prompt_cache,
)

9 changes: 0 additions & 9 deletions llms/mlx_lm/examples/generate_response.py
@@ -23,20 +23,11 @@
# Specify if tokens and timing information will be printed
verbose = True

# Some optional arguments for causal language model generation
generation_args = {
"temp": 0.7,
"repetition_penalty": 1.2,
"repetition_context_size": 20,
"top_p": 0.95,
}

# Generate a response with the specified settings
response = generate(
model=model,
tokenizer=tokenizer,
prompt=prompt,
max_tokens=max_tokens,
verbose=verbose,
**generation_args,
)
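
The removed `generation_args` did not lose their functionality; the same settings now go through `make_sampler` and `make_logits_processors`, the pattern `generate.py` and `server.py` adopt below. A rough sketch of the migrated example, assuming `generate` forwards `logits_processors` the same way `stream_generate` does; the model path and prompt are illustrative:

```python
from mlx_lm import generate, load
from mlx_lm.sample_utils import make_logits_processors, make_sampler

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")  # illustrative path

prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Explain KV caching in two sentences."}],
    tokenize=False,
    add_generation_prompt=True,
)

# Old generation_args -> new helpers:
#   temp=0.7, top_p=0.95 -> make_sampler (positional: temp, top_p)
#   repetition_penalty=1.2, repetition_context_size=20 -> make_logits_processors
sampler = make_sampler(0.7, 0.95)
logits_processors = make_logits_processors(None, 1.2, 20)  # first argument is logit_bias

text = generate(
    model,
    tokenizer,
    prompt=prompt,
    max_tokens=256,
    sampler=sampler,
    logits_processors=logits_processors,
    verbose=True,
)
```
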
46 changes: 4 additions & 42 deletions llms/mlx_lm/generate.py
@@ -7,6 +7,7 @@
import mlx.core as mx

from .models.cache import QuantizedKVCache, load_prompt_cache
from .sample_utils import make_sampler
from .utils import generate, load

DEFAULT_PROMPT = "hello"
@@ -97,11 +98,6 @@ def setup_arg_parser():
default=True,
help="Log verbose output when 'True' or 'T' or only print the response when 'False' or 'F'",
)
parser.add_argument(
"--colorize",
action="store_true",
help="Colorize output based on T[0] probability",
)
parser.add_argument(
"--max-kv-size",
type=int,
@@ -137,33 +133,6 @@ def setup_arg_parser():
return parser


def colorprint(color, s):
color_codes = {
"black": 30,
"red": 31,
"green": 32,
"yellow": 33,
"blue": 34,
"magenta": 35,
"cyan": 36,
"white": 39,
}
ccode = color_codes.get(color, 30)
print(f"\033[1m\033[{ccode}m{s}\033[0m", end="", flush=True)


def colorprint_by_t0(s, t0):
if t0 > 0.95:
color = "white"
elif t0 > 0.70:
color = "green"
elif t0 > 0.30:
color = "yellow"
else:
color = "red"
colorprint(color, s)


def main():
parser = setup_arg_parser()
args = parser.parse_args()
@@ -250,21 +219,14 @@ def main():
else:
prompt = args.prompt

if args.colorize and not args.verbose:
raise ValueError("Cannot use --colorize with --verbose=False")
formatter = colorprint_by_t0 if args.colorize else None

sampler = make_sampler(args.temp, args.top_p, args.min_p, args.min_tokens_to_keep)
response = generate(
model,
tokenizer,
prompt,
args.max_tokens,
max_tokens=args.max_tokens,
verbose=args.verbose,
formatter=formatter,
temp=args.temp,
top_p=args.top_p,
min_p=args.min_p,
min_tokens_to_keep=args.min_tokens_to_keep,
sampler=sampler,
max_kv_size=args.max_kv_size,
prompt_cache=prompt_cache if using_cache else None,
kv_bits=args.kv_bits,
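
With the colorize path gone, `main()` reduces to building a sampler from the CLI flags and passing everything else straight to `generate`. A sketch of the equivalent call from Python, mirroring the keyword arguments visible in this hunk; the concrete values and model path are illustrative:

```python
from mlx_lm import generate, load
from mlx_lm.sample_utils import make_sampler

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")  # illustrative path

# Positional order: temp, top_p, min_p, min_tokens_to_keep.
sampler = make_sampler(0.6, 0.9, 0.05, 1)

text = generate(
    model,
    tokenizer,
    "hello",           # DEFAULT_PROMPT in this file
    max_tokens=256,
    sampler=sampler,
    max_kv_size=1024,  # illustrative cap on the KV cache
    kv_bits=8,         # illustrative KV-cache quantization
    verbose=True,
)
```
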
22 changes: 11 additions & 11 deletions llms/mlx_lm/sample_utils.py
@@ -1,5 +1,6 @@
# Copyright © 2023-2024 Apple Inc.

import math
from functools import partial
from typing import Callable, Dict, Optional

@@ -80,7 +81,7 @@ def logit_bias_processor(_, logits):

@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
def min_p_sampling(
logits: mx.array,
logprobs: mx.array,
min_p: float,
min_tokens_to_keep: int = 1,
temperature=1.0,
@@ -93,7 +94,7 @@ def min_p_sampling(
aggressive given a very high-probability token.
Args:
logits: The logits from the model's output.
logprobs: A vector of log probabilities.
min_p (float): Minimum token probability. Typical values are in the
0.01-0.2 range, comparably selective as setting `top_p` in the
0.99-0.8 range.
@@ -111,28 +112,27 @@
)
# reference implementation: https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py#L531-L605

# Softmax probabilities
probs = mx.softmax(logits * (1 / temperature), axis=-1)
logprobs = logprobs * (1 / temperature)

# Indices sorted in decreasing order
sorted_indices = mx.argsort(-logits).squeeze(0)
sorted_probs = probs[..., sorted_indices]
sorted_indices = mx.argsort(-logprobs).squeeze(0)
sorted_logprobs = logprobs[..., sorted_indices]

# Top probability
top_probs = probs[..., sorted_indices[0]]
top_logprobs = logprobs[..., sorted_indices[0]]

# Calculate the min_p threshold
scaled_min_p = min_p * top_probs
scaled_min_p = top_logprobs + math.log(min_p)

# Mask tokens that have a probability less than the scaled min_p
tokens_to_remove = sorted_probs < scaled_min_p
tokens_to_remove = sorted_logprobs < scaled_min_p
tokens_to_remove[..., :min_tokens_to_keep] = False

# Create pool of tokens with probability less than scaled min_p
selected_probs = mx.where(tokens_to_remove, 0, sorted_probs)
selected_logprobs = mx.where(tokens_to_remove, -float("inf"), sorted_logprobs)

# Return sampled token
sorted_token = mx.random.categorical(mx.log(selected_probs))
sorted_token = mx.random.categorical(selected_logprobs)
return sorted_indices[sorted_token]
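
The threshold moves from probability space (`min_p * top_probs`) to log space (`top_logprobs + log(min_p)`). Since `log` is monotonic, `p >= min_p * p_top` and `log p >= log p_top + log(min_p)` keep exactly the same tokens, and working on logprobs avoids the extra softmax. A small standalone check of that equivalence, using made-up logits:

```python
import math

import mlx.core as mx

logits = mx.array([[2.0, 1.0, 0.5, -1.0, -3.0]])
min_p = 0.1

# Probability-space rule (old code): keep tokens with p >= min_p * p_top.
probs = mx.softmax(logits, axis=-1)
keep_prob = probs >= min_p * probs.max(axis=-1, keepdims=True)

# Log-space rule (new code): keep tokens with log p >= log p_top + log(min_p).
logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
keep_log = logprobs >= logprobs.max(axis=-1, keepdims=True) + math.log(min_p)

print(mx.array_equal(keep_prob, keep_log).item())  # True
```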


42 changes: 19 additions & 23 deletions llms/mlx_lm/server.py
@@ -27,6 +27,7 @@

from ._version import __version__
from .models.cache import make_prompt_cache
from .sample_utils import make_logits_processors, make_sampler
from .utils import load, stream_generate


@@ -464,25 +465,24 @@ def handle_completion(

text = ""
tic = time.perf_counter()
for n, (segment, token, logprobs) in enumerate(
stream_generate(
model=self.model,
tokenizer=self.tokenizer,
prompt=prompt,
max_tokens=self.max_tokens,
temp=self.temperature,
repetition_penalty=self.repetition_penalty,
repetition_context_size=self.repetition_context_size,
logit_bias=self.logit_bias,
prompt_cache=self.prompt_cache.cache,
),
sampler = make_sampler(self.temperature)
logits_processors = make_logits_processors(
self.logit_bias, self.repetition_penalty, self.repetition_context_size
)
for gen_response in stream_generate(
model=self.model,
tokenizer=self.tokenizer,
prompt=prompt,
max_tokens=self.max_tokens,
sampler=sampler,
logits_processors=logits_processors,
prompt_cache=self.prompt_cache.cache,
):
if n == 0:
prompt_time = time.perf_counter() - tic
tic = time.perf_counter()

segment = gen_response.text
text += segment
logging.debug(text)
token = gen_response.token
logprobs = gen_response.logprobs
tokens.append(token)

if self.logprobs > 0:
@@ -523,13 +523,9 @@ def handle_completion(

self.prompt_cache.tokens.extend(tokens)

gen_time = time.perf_counter() - tic
prompt_tps = len(prompt) / prompt_time
gen_tps = len(tokens) / gen_time
peak_mem = mx.metal.get_peak_memory() / 1e9
logging.debug(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
logging.debug(f"Generation: {gen_tps:.3f} tokens-per-sec")
logging.debug(f"Peak memory: {peak_mem:.3f} GB")
logging.debug(f"Prompt: {gen_response.prompt_tps:.3f} tokens-per-sec")
logging.debug(f"Generation: {gen_response.generation_tps:.3f} tokens-per-sec")
logging.debug(f"Peak memory: {gen_response.peak_memory:.3f} GB")

if self.stream:
response = self.generate_response(segment, finish_reason)
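
The handler now builds its sampler and logits processors once per request and reads the throughput numbers off the final response instead of recomputing them. The same pattern outside the server, sketched with a hypothetical `logit_bias` dict (token id to additive bias) standing in for `self.logit_bias`; the model path, prompt, and bias values are illustrative:

```python
from mlx_lm import load, stream_generate
from mlx_lm.sample_utils import make_logits_processors, make_sampler

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")  # illustrative path

logit_bias = {13: -5.0}  # hypothetical: discourage token id 13

sampler = make_sampler(0.8)  # temperature only, as in handle_completion
logits_processors = make_logits_processors(logit_bias, 1.1, 20)

text = ""
for gen_response in stream_generate(
    model,
    tokenizer,
    "Write one sentence about rivers.",
    max_tokens=64,
    sampler=sampler,
    logits_processors=logits_processors,
):
    text += gen_response.text

print(text)
print(f"Prompt: {gen_response.prompt_tps:.3f} tokens-per-sec")
print(f"Generation: {gen_response.generation_tps:.3f} tokens-per-sec")
print(f"Peak memory: {gen_response.peak_memory:.3f} GB")
```
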
11 changes: 4 additions & 7 deletions llms/mlx_lm/tokenizer_utils.py
@@ -73,16 +73,16 @@ def __init__(self, tokenizer):

def reset(self):
self.offset = 0
self._tokens = []
self.tokens = []
self._text = ""
self._current_tokens = []
self._current_text = ""

def add_token(self, token):
self._current_tokens.append(token)
self.tokens.append(token)

def finalize(self):
self._tokens.extend(self._current_tokens)
self._text += self._tokenizer.decode(self._current_tokens)
self._current_tokens = []
self._current_text = ""
@@ -97,16 +97,11 @@ def text(self):
):
self._current_text = self._current_text[:-1]
if self._current_text and self._current_text[-1] == "\n":
self._tokens.extend(self._current_tokens)
self._text += self._current_text
self._current_tokens.clear()
self._current_text = ""
return self._text + self._current_text

@property
def tokens(self):
return self._tokens


class SPMStreamingDetokenizer(StreamingDetokenizer):
"""A streaming detokenizer for SPM models.
Expand Down Expand Up @@ -143,6 +138,7 @@ def _flush(self):
self.text += text

def add_token(self, token):
self.tokens.append(token)
v = self.tokenmap[token]
if v.startswith(self._sep):
self._flush()
@@ -200,6 +196,7 @@ def _maybe_trim_space(self, current_text):
return current_text

def add_token(self, token):
self.tokens.append(token)
v = self.tokenmap[token]
is_added = token in self._added_ids
if is_added or self._byte_decoder[v[0]] == 32:
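
`tokens` is now a plain attribute that every `add_token` implementation appends to, rather than a property backed by `_tokens`. A rough sketch of exercising a streaming detokenizer by hand, assuming the loaded tokenizer wrapper exposes it as `tokenizer.detokenizer`; the `reset`/`add_token`/`finalize` flow follows this file, while the model path and sample sentence are illustrative:

```python
from mlx_lm import load

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")  # illustrative path

detokenizer = tokenizer.detokenizer
detokenizer.reset()

for token in tokenizer.encode("Streaming detokenization keeps tokens and text in sync."):
    detokenizer.add_token(token)
detokenizer.finalize()

# After this change, `tokens` is an ordinary list kept in step with the decoded text.
print(detokenizer.tokens[:5])
print(detokenizer.text)
```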
