Support passing custom sampling function. #200

Open · wants to merge 9 commits into base: main
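In summary, this PR lets callers pass a custom sampling function into prefill(); the callable travels with the request via Prefix, is copied into DecodeState on insert, and replaces the environment-wide sampling settings at each decode step. A rough sketch of the intended call shape — the top-k helper below is illustrative, not code from this diff:

import jax
import jax.numpy as jnp

# Any callable mapping logits -> token ids can serve as a sampler.
# jax.tree_util.Partial keeps the bound callable a pytree, so the engine
# can pass it through jit-compiled code as an ordinary argument.
def sample_topk(logits, rng, temperature, topk):
  vals, idx = jax.lax.top_k(logits, topk)  # keep the k best logits
  choice = jax.random.categorical(rng, vals / temperature)
  return jnp.take_along_axis(idx, choice[..., None], axis=-1).squeeze(-1)

sampler = jax.tree_util.Partial(
    sample_topk, rng=jax.random.key(0), temperature=1.0, topk=8
)
# prefill_result, _ = engine.prefill(
#     params=params, padded_tokens=tokens, true_length=true_length,
#     sampler=sampler,
# )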
115 changes: 90 additions & 25 deletions jetstream_pt/engine.py
@@ -28,8 +28,7 @@
import torch
import numpy as np

- from jetstream.engine import engine_api, tokenizer_api, tokenizer_pb2, token_utils
- from jetstream.engine import sampling_utils
+ from jetstream.engine import engine_api, tokenizer_api, tokenizer_pb2, token_utils, sampling_utils
import torch_xla2
from torch.utils import _pytree as pytree

@@ -44,6 +43,7 @@
from jetstream_pt.third_party.mixtral import config as mixtral_config, model as mixtral_model

from absl import flags
from collections.abc import Callable

FLAGS = flags.FLAGS

@@ -60,6 +60,7 @@ class Prefix:
token: jax.Array # [1, seqlen]
caches: List[Tuple[jax.Array, jax.Array]]
seq_len: int # true seqlen front pad
sampler: List[Any] | int # User-defined sampler


@struct.dataclass
@@ -73,8 +74,12 @@ class DecodeState:
current_position: int
lens: jax.Array # [batch_size, 1], the output token length
start: jax.Array # [batch_size, 1], the starting pos for each slot
- input_pos: jax.Array # [batch_size, 1] input pos for each slot
+ input_pos: (
+ jax.Array
+ ) # [batch_size, 1] total (prefill + decode) length for each slot
mask: jax.Array # [batch_size, seqlen] -inf for invalid; 0 for valid
# The sampling function
samplers: Any


# NOTE model specific
@@ -93,7 +98,8 @@ def __init__(
self.pt_model = pt_model
self.env = env
self.default_dtype = jnp.bfloat16 if env.bf16_enable else jnp.float32
- self.rng = jax.random.PRNGKey(0)
+ self.rng = jax.random.key(0)

self.weights = weights

self.y_sharding = env.sharding_by_axis(1)
@@ -119,6 +125,7 @@ def __init__(
donate_argnums=(1,),
out_shardings=(self.get_decode_state_sharding(), None),
)
# self.generate = self.generate_impl

if self.env.page_attention:
max_pages_per_sequence = (
@@ -168,6 +175,7 @@ def init_decode_state(
scalers = []
if self.env.quant_config.enable_kv_quantization:
scalers = [c.scalers() for c in caches_obj]

return DecodeState(
jnp.zeros((self.env.batch_size, 1), dtype=jnp.int32),
caches,
@@ -181,6 +189,7 @@ def init_decode_state(
float("-inf"),
dtype=self.default_dtype,
), # mask
None,
)

# pylint: disable-next=all
@@ -280,19 +289,42 @@ def _call_model_prefill(self, weights, tokens, input_indexes):
caches_res = [c.state() for c in caches]
return torchjax.from_torch((res, caches_res))

- def _sampling(self, logits: Any, batch_size: int) -> jnp.ndarray:
+ # Temporarily disabled because handling per-request sampling is not ready yet.
+ # @classmethod
+ # def _custom_sampling(self, logits, samplers) -> jnp.ndarray:
+ # if len(logits.shape) == 2:
+ # logits = jnp.expand_dims(logits, 0)

+ # logits = logits[:, -1]

+ # # Prefill and Generate have different batch sizes
+ # current_batch_size = logits.shape[0]

+ # idx = jnp.arange(current_batch_size)
+ # apply_sampler = lambda i, l: jax.lax.switch(i, samplers, l)
+ # apply_vmap = jax.vmap(apply_sampler, in_axes=(0, 0))
+ # return apply_vmap(idx, logits).reshape(current_batch_size, -1)

+ def _sampling(
+ self, logits: Any, algorithm, rng, temperature, topk, nucleus_topp
+ ) -> jnp.ndarray:
+ if len(logits.shape) == 2:
+ logits = jnp.expand_dims(logits, 0)

+ logits = logits[:, -1]
+ current_batch_size = logits.shape[0]

return (
sampling_utils.sampling(
- logits[:, -1],
- self.rng,
- self.env.sampling_algorithm,
- self.env.topk,
- self.env.nucleus_topp,
- self.env.temperature,
+ logits=logits,
+ rng=rng,
+ algorithm=algorithm,
+ topk=topk,
+ nucleus_topp=nucleus_topp,
+ temperature=temperature,
)
- .reshape(batch_size, -1)
+ .reshape(current_batch_size, -1)
.astype(jnp.int32)
)

@@ -301,7 +333,7 @@ def prefill(
*,
params: Any, # Weights
existing_prefix: Optional[Prefix] = None,
- padded_tokens: PrefillInputs, # PrefillInputs[jax.Array],
+ padded_tokens: PrefillInputs, # PrefillInputs[jax.Array]
true_length: int,
sampler: Optional[Callable[[Any], Any]] = None,
) -> Tuple[Prefix, engine_api.ResultTokens]:
@@ -321,6 +353,7 @@
)
if len(logits.shape) == 3: # b, seqlen, num words
logits = logits[0] # seqlen, num words

if sampler:
token = sampler(logits[true_length - 1])
else:
@@ -332,6 +365,7 @@
self.env.nucleus_topp,
self.env.temperature,
)
token = jnp.reshape(token, (1,))
token_out = jnp.reshape(token, (1, 1))
data = jnp.concatenate(
[
@@ -357,7 +391,10 @@
# v, seq_len - true_length, true_length, axis=2))
# for k, v in updated_caches
# ]
- return Prefix(token, updated_caches, true_length), result
+ return (
+ Prefix(token, updated_caches, true_length, sampler),
+ result,
+ )

def shrink_prefix(
self,
@@ -476,6 +513,8 @@ def insert(cache, scaler, new_entry, update_index):
caches.append((kcache, vcache))
scales.append((kscale, vscale))
lens = decode_state.lens.at[slot].set(1)

sampler = prefix.sampler if prefix.sampler else decode_state.samplers
return DecodeState(
tokens,
caches,
Expand All @@ -485,6 +524,7 @@ def insert(cache, scaler, new_entry, update_index):
start,
input_pos,
mask,
sampler,
)

# pylint: disable-next=all
@@ -569,6 +609,9 @@ def insert(cache, scaler, new_entry):
scales.append((kscale, vscale))

lens = decode_state.lens.at[slot].set(1)

sampler = prefix.sampler if prefix.sampler else decode_state.samplers

return DecodeState(
tokens,
caches,
Expand All @@ -578,6 +621,7 @@ def insert(cache, scaler, new_entry):
start,
input_pos,
mask,
sampler,
)

def _insert_page_attention(
@@ -613,6 +657,8 @@ def _insert_page_attention(
input_pos = decode_state.input_pos.at[slot].set(prefix.seq_len)
scales = None
lens = decode_state.lens.at[slot].set(1)

sampler = prefix.sampler if prefix.sampler else decode_state.samplers
return DecodeState(
tokens,
caches,
Expand All @@ -622,6 +668,7 @@ def _insert_page_attention(
start,
input_pos,
mask,
sampler,
)

def insert(
@@ -729,7 +776,9 @@ def false_comp(b, i, bk, start, end):
return b_next, i_next

def generate(
- self, params: Any, decode_state: DecodeState, sampler=None
+ self,
+ params: Any,
+ decode_state: DecodeState,
) -> tuple[DecodeState, engine_api.ResultTokens]:
return (None, None)

@@ -752,7 +801,6 @@ def generate_impl(
self,
params: Any,
decode_state: DecodeState,
- sampler=None,
page_token_indices=None,
) -> tuple[DecodeState, engine_api.ResultTokens]:
# seq_len = padded_tokens.shape[0]
@@ -764,12 +812,16 @@
else:
input_indexes = decode_state.input_pos

- ragged_batch_index, ragged_block_index = (
- self.precompute_ragged_block_indices(decode_state)
- )
- ragged_batch_index, ragged_block_index = ragged_batch_index.reshape(
- (-1)
- ), ragged_block_index.reshape((-1))
+ # TODO(lancewang): Remove ragged index precomputation
+ # ragged_batch_index, ragged_block_index = (
+ # self.precompute_ragged_block_indices(decode_state)
+ # )
+ # ragged_batch_index, ragged_block_index = ragged_batch_index.reshape(
+ # (-1)
+ # ), ragged_block_index.reshape((-1))

+ ragged_batch_index = 0
+ ragged_block_index = 0

def update_mask():
if self.env.ring_buffer:
@@ -799,10 +851,20 @@ def update_mask():
# fill mask later, now use flash attention
mask = update_mask()

- if sampler:
- next_token = sampler(logits[:, -1])
+ # Temporarily disabled because handling per-request sampling is not ready yet.
+ # next_token = self._custom_sampling(logits, decode_state.samplers)
+ if decode_state.samplers:
+ next_token = decode_state.samplers(logits)
else:
- next_token = self._sampling(logits, self.env.batch_size)
+ next_token = self._sampling(
+ logits,
+ self.env.sampling_algorithm,
+ self.rng,
+ self.env.temperature,
+ self.env.topk,
+ self.env.nucleus_topp,
+ )

if self.env.ring_buffer:
input_pos = decode_state.input_pos + 1
lens = decode_state.lens + 1
@@ -844,6 +906,7 @@ def update_mask():
decode_state.start,
input_pos,
mask,
decode_state.samplers,
)
return new_decode_state, result_tokens

@@ -963,6 +1026,7 @@ def get_prefix_destination_sharding(self) -> Prefix:
if self.env.page_attention
else self.cache_sharding,
self.replicated,
self.replicated,
)

def get_decode_state_sharding(self) -> DecodeState:
Expand All @@ -976,6 +1040,7 @@ def get_decode_state_sharding(self) -> DecodeState:
self.replicated,
self.replicated,
self.replicated,
self.replicated,
)

def get_prefix_sequence_ddim(self) -> Any:
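For reference, the temporarily disabled _custom_sampling above dispatches a different sampler to each batch slot by indexing into a tuple of samplers with jax.lax.switch under vmap. A standalone sketch of that pattern, with made-up samplers and shapes:

import jax
import jax.numpy as jnp

# lax.switch needs every branch to return the same shape/dtype, so both
# hypothetical samplers return a scalar int32 token for one logits row.
def greedy(row):
  return jnp.argmax(row).astype(jnp.int32)

def warm(row, rng=jax.random.key(0), temperature=0.7):
  return jax.random.categorical(rng, row / temperature).astype(jnp.int32)

SAMPLERS = (greedy, warm)

def sample_per_slot(sampler_idx, logits):
  # sampler_idx: [batch] int32, logits: [batch, vocab]
  pick = lambda i, row: jax.lax.switch(i, SAMPLERS, row)
  return jax.vmap(pick)(sampler_idx, logits)

logits = jax.random.normal(jax.random.key(1), (4, 32000))
tokens = sample_per_slot(jnp.array([0, 1, 0, 1]), logits)  # shape [4]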
5 changes: 4 additions & 1 deletion jetstream_pt/environment.py
@@ -52,7 +52,10 @@ class JetEngineEnvironmentData:
batch_size: int = 32 # batch size is generate step batch size
cache_sequence_length: int = 2048 # size of the cache.

- quant_config: QuantizationConfig = QuantizationConfig()
+ # quant_config: QuantizationConfig = QuantizationConfig()
+ quant_config: QuantizationConfig = dataclasses.field(
+ default_factory=QuantizationConfig
+ )

model_type: str = "llama-2-13b" # this implies the model config

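The quant_config change above is housekeeping bundled into this PR: a bare QuantizationConfig() default would be shared by every JetEngineEnvironmentData instance (and newer Python versions reject unhashable dataclass defaults like this outright), so dataclasses.field(default_factory=QuantizationConfig) builds a fresh config per instance instead.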
21 changes: 20 additions & 1 deletion run_interactive.py
@@ -23,13 +23,29 @@
from absl import app
from jetstream.engine import token_utils
from jetstream_pt.config import FLAGS, create_engine_from_config_flags
from jetstream.engine import sampling_utils


# pylint: disable-next=all
def main(argv):

engine = create_engine_from_config_flags()

rng = jax.random.key(1)
temperature = 1
topk = 1
topp = 0.2

sampler = jax.tree_util.Partial(
sampling_utils.jittable_sample_topk_logits,
rng=rng,
temperature=temperature,
topk=topk,
)
# sampler = jax.tree_util.Partial(sampling_utils.jittable_sample_topp_logits, rng=rng, temperature=temperature, topp=topp)
# sampler = jax.tree_util.Partial(sampling_utils.jittable_sample_greedy_logits)
# sampler = jax.tree_util.Partial(sampling_utils.jittable_sample_weighted_logits, rng=rng, temperature=temperature)

start = time.perf_counter()
params = engine.load_params()
print("Load params ", time.perf_counter() - start)
@@ -77,7 +93,10 @@ def main(argv):
jax.profiler.start_trace(profiling_output)

prefill_result, _ = engine.prefill(
- params=params, padded_tokens=tokens, true_length=true_length
+ params=params,
+ padded_tokens=tokens,
+ true_length=true_length,
+ sampler=sampler
)
# pylint: disable-next=all
decode_state = engine.insert(prefill_result, decode_state, slot=slot)
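One note on the sampler construction in run_interactive.py above: jax.tree_util.Partial is used rather than functools.partial because Partial is registered as a pytree, so the bound sampler can cross jit boundaries as a regular argument. A minimal illustration with made-up names:

import jax
import jax.numpy as jnp

@jax.jit
def apply(fn, x):
  # fn arrives as a pytree (a jax.tree_util.Partial), not a static arg
  return fn(x)

scale_by = jax.tree_util.Partial(lambda x, s: x * s, s=2.0)
print(apply(scale_by, jnp.arange(3.0)))  # [0. 2. 4.]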
6 changes: 4 additions & 2 deletions tests/helpers.py
@@ -7,7 +7,9 @@


# pylint: disable-next=all
- def make_env_tiny(bf16_enable=True, env_data_update_fn=lambda _: None):
+ def make_env_tiny(
+ bf16_enable=True, env_data_update_fn=lambda _: None, batch_size=1
+ ):
torch_dtype = torch.bfloat16 if bf16_enable else torch.float32
torch.set_default_dtype(torch_dtype)
jax.config.update("jax_dynamic_shapes", False)
@@ -19,7 +21,7 @@ def make_env_tiny(bf16_enable=True, env_data_update_fn=lambda _: None):
environment_data.cache_sequence_length = 128
environment_data.bf16_enable = bf16_enable
environment_data.model_type = "llama-2-tiny"
- environment_data.batch_size = 1
+ environment_data.batch_size = batch_size
environment_data.num_layers = config.n_layers
environment_data.cache_shape = (
1,