Add static KV cache and test on Gemma-2B #4
Conversation
It previously used n_positions in some cases, but that attribute is not available on some model configs.
Force-pushed from c342e4f to 5542841.
If the DBG_DEVICE env var is set, it will be used to set the device for the model.
This will avoid loading the model twice.
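A minimal sketch of what such an override can look like (the helper name and default are assumptions, not the PR's actual code):

import os

import torch


def resolve_device(default: str = "xla") -> torch.device:
    # If DBG_DEVICE is set, it takes precedence over the default device.
    return torch.device(os.environ.get("DBG_DEVICE", default))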
Make compilation optional; it can be enabled with the environment variable DBG_COMPILE. This is because:
1. Some models produce bugs when compiled (notably Gemma).
2. The shapes of the models' inference input params change, triggering recompilation and leading to slow performance.
3. With the added xm.mark_step, performance is actually better when the model is not compiled: XLA builds a graph anyway, so performance is going to be good.
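A hedged sketch of gating compilation behind an env var; the helper and the torch.compile backend string are assumptions:

import os

import torch
import torch_xla.core.xla_model as xm


def maybe_compile(model):
    # Compile only when explicitly requested: some models (notably Gemma)
    # produce bugs when compiled, and changing input shapes trigger slow
    # recompilations.
    if os.environ.get("DBG_COMPILE"):
        model = torch.compile(model, backend="openxla")  # backend name is an assumption
    return model


# Even without torch.compile, XLA traces lazily; calling xm.mark_step()
# after each decoding step cuts and executes the accumulated graph.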
This is to reduce useless gradient calculations.
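This typically amounts to running the inference path under torch.no_grad(); a minimal sketch:

import torch


@torch.no_grad()  # skip autograd bookkeeping during generation
def decode_step(model, input_ids, attention_mask):
    return model(input_ids=input_ids, attention_mask=attention_mask)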
This will make it possible to handle passing different params for different model configurations later.
Some models, like Gemma and Llama, support static KV cache in transformers. For these, it is possible to use this feature, leading to much higher performance.
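A minimal sketch of opting into the static cache, assuming a transformers release recent enough to support it on these models:

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("google/gemma-2b")
# Supported models (e.g. Gemma, Llama) can pre-allocate the KV cache to a
# fixed shape, which keeps XLA graph shapes stable across decoding steps.
model.generation_config.cache_implementation = "static"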
Also, manually install accelerate to avoid memory issues when loading Gemma.
Force-pushed from 5542841 to 27a2669.
The test produces different results because some operations are now done in a slightly different order.
LGTM - a few more comments for further reflection moving forward - congrats!
self._id = id
self._tokenizer = tokenizer
self.clear()
self._device = device
Maybe let's do the conversion from str to torch.device() right away here to ensure we can fail fast if this device doesn't exist and avoid overhead later down the road?
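What the suggestion amounts to, as a sketch (the class and parameter names are illustrative):

import torch


class Slot:
    def __init__(self, device: str):
        # Convert once at construction; a malformed device string fails
        # fast here instead of at first use.
        self._device = torch.device(device)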
The conversion does not check that the device is available. The only way I found to check whether the device is available is to invoke the torch_xla API directly. I can add a check before mapping the model if you wish.
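For reference, a sketch of probing availability through the torch_xla API directly; exact behavior depends on the torch_xla version:

import torch_xla.core.xla_model as xm

# Lists the XLA devices the runtime can actually see, e.g. ['xla:0'].
devices = xm.get_xla_supported_devices()
if not devices:
    raise RuntimeError("No XLA device available")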
As discussed offline, adding such a check is probably unnecessary, given that the check will happen implicitly while mapping the model.
)
# Update mask only if it was set previously
if self._mask is not None:
    self._mask = torch.cat([self._mask, torch.tensor([1], device=self._device, dtype=self._mask.dtype)])
Maybe for later: can this concatenation be replaced by an in-place set from 0 to 1?
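What the in-place alternative could look like, as a sketch (the pre-allocated length and positions are assumptions):

import torch

max_len = 1024  # assumed pre-allocated maximum sequence length
mask = torch.zeros(max_len, dtype=torch.int64)
mask[:16] = 1   # positions filled at prefill (16 is illustrative)
mask[16] = 1    # flip 0 -> 1 in place instead of torch.cat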
Sure, I'll make a note of it.
Having said that: this is handled transparently by models that use the static cache, so I guess they already do that inside the model.
What does this PR do?
This PR adapts the TGI server to better take advantage of PyTorch/XLA graphs. Relevant changes:
All this leads to general performance enhancements: even though I added a test with a new model, the tests run in 4m40s, whereas before they were running in 5m21s.