-
Notifications
You must be signed in to change notification settings - Fork 24
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Correct extra token, start preparing docker image for TGI/Jetstream Pt #93
Conversation
Jetstream's `generate` function returns input token as result token. The next token is instead available in the decode_state, so this change uses this instead.
Before, we could have an error is the seed was bigger than a 64 bit number.
This allows testing TGI images with Jetstream Pytorch.
This is required when there are no more tokens generated after prefill.
A new slot is created at each prefill request, and its selector is passed as argument to a jitted function. The problem is that each new slot has a new signature, even if the contents are the same. The solution is to wrap that in a singleton slot object for the prefill, so the compiler will always see the same object and stop recompiling.
16, | ||
32, | ||
64, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Isnt it a bit too small?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
umh, it's actually aligned with Jetstream's buckets. But I just realized that's been extracted from the function, so I can import the constant now, I'll make the change.
@@ -55,6 +55,8 @@ def __init__( | |||
self.eos_token_ids = eos_token_ids | |||
self.pad_token_id = pad_token_id | |||
self.logits_warper = logits_warper | |||
# Seed needs to fit a 64-bit integer, so we modulo it in case is bigger (that can happen!) | |||
seed = seed % jnp.iinfo(jnp.int64).max |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can't we just store it as int64? Modulo op sounds very expensive
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The seed comes from the TGI router and sometimes it comes as a crazy big number (bigger than int64). numpy and torch seem to silently do the operation, but jax
requires the seed to be an int64 and it raises an error if the number is higher. Let me know if you think about a "cheaper" alternative.
What does this PR do?
There was an issue related to the fact that decode used to return the token passed in the decode state as result, making TGI see as if the first token was returned twice. This is resolved by using the data in the decode state (the API should be stable, according to internal discussions with the Jetstream Pytorch team).
Also, the docker image is now capable of serving TGI using Jetstream Pytorch on the supported models.
Before submitting