Server: fix seed for multiple slots #6835
9fa0876
to
d924e61
Compare
llama.h
Outdated
/// @details Randomly selects a token from the candidates based on their probabilities using a given pointer to a std::mt19937.
LLAMA_API llama_token llama_sample_token_with_rng(
    struct llama_context * ctx,
    llama_token_data_array * candidates,
    void * rng);
Not ok to pass a pointer to a C++ object. This is supposed to be a C API.
Good point. What do you think would be a good alternative way to fix this?
Do you think it would make sense to add an API for saving/loading RNG state via char *?
I can think of two alternatives:
- Keep one RNG per sequence in llama_context, and add a new llama_sample_token function that receives a sequence id. This would also require updating the state saving code.
- Add a new llama_sample_token function that receives a value between 0 and 1 and samples based on that value.
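The second alternative could look roughly like the sketch below. It is only an illustration, not code from this PR: `token_prob` and `sample_token_from_uniform` are hypothetical names, and the candidate probabilities are assumed to be normalized already. The caller owns the RNG and passes in only a uniform value, so no C++ object crosses the C API.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for llama_token_data: only the fields this sketch needs.
struct token_prob {
    int   id; // token id
    float p;  // probability; the p values are assumed to sum to ~1
};

// Select a token by inverse-CDF sampling: walk the cumulative distribution
// until it exceeds u. The caller draws u in [0, 1) from its own RNG.
int sample_token_from_uniform(const std::vector<token_prob> & candidates, float u) {
    assert(!candidates.empty());
    assert(u >= 0.0f && u < 1.0f);
    float cumulative = 0.0f;
    for (const token_prob & c : candidates) {
        cumulative += c.p;
        if (u < cumulative) {
            return c.id;
        }
    }
    // Rounding can leave the cumulative sum slightly below 1; fall back to the last candidate.
    return candidates.back().id;
}
```

With this shape, each server slot could keep its own generator and call something like `std::uniform_real_distribution<float>` on it before sampling.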
> Do you think it would make sense to add an API for saving/loading RNG state via char *?
I wouldn't like it very much as a solution: it would complicate the API and it would be inefficient. Even though in practice it is not likely to have a measurable effect on performance, the RNG state size can be significant.
Seems like the RNG state is suitable to be part of the llama_sampling_context. Should we fix this issue within the scope of #5214?
The problem is that the internal state of the RNG is not a simple unsigned int like the seed but rather 19937 bits, and the way to set RNG state via string stream is kind of cumbersome. Would it make sense to create a struct like llama_rng that just packages a random number generator, add it to the sampling context, and pass it for sampling?
In any case, what is the intended scope and functionality of saving/loading the state of llama_context? Should it include the sampling RNG state?
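For reference, the string-stream round trip mentioned above can be sketched as follows. The helper names `save_rng_state` and `load_rng_state` are hypothetical and not part of llama.cpp; the stream operators on `std::mt19937` are standard C++.

```cpp
#include <cassert>
#include <random>
#include <sstream>
#include <string>

// Serialize the full engine state to text using the standard stream operator.
std::string save_rng_state(const std::mt19937 & rng) {
    std::ostringstream out;
    out << rng;
    return out.str();
}

// Restore an engine from text produced by save_rng_state().
void load_rng_state(std::mt19937 & rng, const std::string & state) {
    std::istringstream in(state);
    in >> rng;
    assert(!in.fail());
}
```

The serialized text holds the whole state table (several hundred integers), which is why it is much larger than the seed.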
> Would it make sense to create a struct like llama_rng that just packages a random number generator, add it to the sampling context, and pass it for sampling?
The struct llama_sampling_context is basically that - it should contain the full state related to sampling, so it makes sense to also move the RNG state there. It will take some larger refactoring to do that, so that's why I proposed to do it later in the scope of #5214.
I don't think the RNG state should be stored with the llama_context state.
> The struct llama_sampling_context is basically that - it should contain the full state related to sampling, so it makes sense to also move the RNG state there.
What I meant specifically: on this PR I'm passing an instance of std::mt19937 via void *, but that's kind of a bad solution. Should I change this to a struct llama_rng defined in llama.cpp so that the sampling code can pass that instead of void *? Or should this PR be shelved until the sampling code is overhauled? To be clear, I am not interested in doing the overhaul myself; my primary goal is to get a baseline reproducible behavior that I can test my implementation in #6828 against.
You can move the llama_sample_token_with_rng() call to the internal C++ API over here:
Lines 1073 to 1076 in 4e96a81
// Internal API to be implemented by llama.cpp and used by tests/benchmarks only
#ifdef LLAMA_API_INTERNAL
Add a comment that this is a temporary workaround. In the future, we'll update the sampling API to use llama_sampling_context as planned in #5214
Should the existing functions for sampling in |
d924e61
to
123eaf0
Compare
I just noticed, should |
I've added a |
I think it's ok to merge this workaround in view of having reproducible results. I'll try to prioritise improving the sampling API soon
On master the server currently does not produce reproducible results for a given seed. The problem is that all slots use a common RNG in llama_context. This PR fixes this by adding a separate RNG to each llama_sampling_context. The llama.cpp API is extended with a sampling function that accepts an external RNG.
This PR also adds the server test I originally wrote for #6828. It fails without the other changes in this PR.
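The effect of a per-slot RNG can be illustrated with a small standalone sketch. This is not the server code: `slot` and `draw` are hypothetical names, and raw generator outputs stand in for sampled tokens.

```cpp
#include <cassert>
#include <cstdint>
#include <random>
#include <vector>

// Hypothetical stand-in for a server slot; in the PR the RNG lives in llama_sampling_context.
struct slot {
    std::mt19937 rng;
    explicit slot(uint32_t seed) : rng(seed) {}
};

// Draw n raw values from the slot's own RNG. Because no other slot touches
// this generator, the sequence depends only on the slot's seed.
std::vector<std::mt19937::result_type> draw(slot & s, int n) {
    assert(n >= 0);
    std::vector<std::mt19937::result_type> out;
    for (int i = 0; i < n; ++i) {
        out.push_back(s.rng());
    }
    return out;
}
```

If two slots instead shared one generator, the values each slot received would depend on how their requests were interleaved, which is the non-reproducibility described above.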