llama : move random seed generation to the samplers #9398

Merged (5 commits) on Sep 10, 2024
common/arg.cpp (1 addition, 6 deletions)

@@ -173,7 +173,6 @@ static bool gpt_params_parse_ex(int argc, char ** argv, gpt_params_context & ctx
     std::string arg;
     const std::string arg_prefix = "--";
     gpt_params & params = ctx_arg.params;
-    gpt_sampler_params & sparams = params.sparams;
 
     std::unordered_map<std::string, llama_arg *> arg_to_options;
     for (auto & opt : ctx_arg.options) {
@@ -283,10 +282,6 @@ static bool gpt_params_parse_ex(int argc, char ** argv, gpt_params_context & ctx
         params.kv_overrides.back().key[0] = 0;
     }
 
-    if (sparams.seed == LLAMA_DEFAULT_SEED) {
-        sparams.seed = time(NULL);
-    }
-
     return true;
 }
 
@@ -909,7 +904,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
     ).set_sparam());
     add_opt(llama_arg(
         {"-s", "--seed"}, "SEED",
-        format("RNG seed (default: %d, use random seed for < 0)", params.sparams.seed),
+        format("RNG seed (default: %u, use random seed for %u)", params.sparams.seed, LLAMA_DEFAULT_SEED),
         [](gpt_params & params, const std::string & value) {
            params.sparams.seed = std::stoul(value);
         }
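The help-text change is worth a note: the seed is a uint32_t now, and the old "< 0 means random" convention survives only through unsigned wrap-around. A minimal standalone sketch (my own, not part of the PR), assuming LLAMA_DEFAULT_SEED is the 0xFFFFFFFF sentinel from llama.h:

    // Passing "-s -1" still requests a random seed: std::stoul("-1") returns
    // ULONG_MAX, which truncates to 0xFFFFFFFF when stored in a uint32_t,
    // i.e. exactly the sentinel the samplers now resolve lazily.
    #include <cstdint>
    #include <cstdio>
    #include <string>

    int main() {
        const uint32_t kDefaultSeed = 0xFFFFFFFF; // assumed LLAMA_DEFAULT_SEED
        const uint32_t seed = (uint32_t) std::stoul("-1");
        printf("random seed requested: %s\n", seed == kDefaultSeed ? "yes" : "no");
    }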
common/sampling.cpp (4 additions)

@@ -310,6 +310,10 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context
     return cur_p.data[cur_p.selected].id;
 }
 
+uint32_t gpt_sampler_get_seed(const struct gpt_sampler * gsmpl) {
+    return llama_sampler_get_seed(gsmpl->chain);
+}
+
 // helpers
 
 llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl) {
common/sampling.h (2 additions)

@@ -60,6 +60,8 @@ void gpt_perf_print(const struct llama_context * ctx, const struct gpt_sampler *
 //
 llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);
 
+uint32_t gpt_sampler_get_seed(const struct gpt_sampler * gsmpl);
+
 // helpers
 
 // access the internal list of current candidate tokens
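A sketch of the intended call pattern for these two helpers, assuming the gpt_sampler_init signature from common/sampling.h; the model and sampler params come from the usual gpt_params plumbing:

    #include <cstdio>
    #include "sampling.h" // common library header; include path is an assumption

    // Build the sampler chain, then report the seed it actually ended up
    // with. When sparams.seed == LLAMA_DEFAULT_SEED the chain draws its own
    // random seed, so the requested and effective values can differ.
    static gpt_sampler * init_sampler_logged(const llama_model * model,
                                             const gpt_sampler_params & sparams) {
        gpt_sampler * smpl = gpt_sampler_init(model, sparams);
        fprintf(stderr, "requested seed: %u, effective seed: %u\n",
                sparams.seed, gpt_sampler_get_seed(smpl));
        return smpl;
    }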
examples/embedding/embedding.cpp (2 deletions)

@@ -90,8 +90,6 @@ int main(int argc, char ** argv) {
 
     print_build_info();
 
-    LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed);
-
     llama_backend_init();
     llama_numa_init(params.numa);
 
examples/infill/infill.cpp (3 additions, 4 deletions)

@@ -159,8 +159,6 @@ int main(int argc, char ** argv) {
 
     print_build_info();
 
-    LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed);
-
     LOG("%s: llama backend init\n", __func__);
     llama_backend_init();
     llama_numa_init(params.numa);
@@ -301,6 +299,9 @@ int main(int argc, char ** argv) {
             LOG_TEE("Input suffix: '%s'\n", params.input_suffix.c_str());
         }
     }
+    smpl = gpt_sampler_init(model, sparams);
+
+    LOG_TEE("sampling seed: %u\n", gpt_sampler_get_seed(smpl));
     LOG_TEE("sampling: \n%s\n", sparams.print().c_str());
     LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
     LOG_TEE("\n\n");
@@ -340,8 +341,6 @@ int main(int argc, char ** argv) {
 
     std::vector<llama_token> embd;
 
-    smpl = gpt_sampler_init(model, sparams);
-
     while (n_remain != 0 || params.interactive) {
         // predict
         if (!embd.empty()) {
examples/main/main.cpp (3 additions, 3 deletions)

@@ -191,8 +191,6 @@ int main(int argc, char ** argv) {
 
     print_build_info();
 
-    LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed);
-
     LOG("%s: llama backend init\n", __func__);
     llama_backend_init();
     llama_numa_init(params.numa);
@@ -470,8 +468,10 @@ int main(int argc, char ** argv) {
         exit(1);
     }
 
+    LOG_TEE("sampling seed: %u\n", gpt_sampler_get_seed(smpl));
     LOG_TEE("sampling params: \n%s\n", sparams.print().c_str());
-    LOG_TEE(" sampler constr: \n%s\n", gpt_sampler_print(smpl).c_str());
+    LOG_TEE("sampler constr: \n%s\n", gpt_sampler_print(smpl).c_str());
+
     LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
 
     // group-attention state
examples/perplexity/perplexity.cpp (2 deletions)

@@ -2007,8 +2007,6 @@ int main(int argc, char ** argv) {
 
     print_build_info();
 
-    LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed);
-
     llama_backend_init();
     llama_numa_init(params.numa);
 
examples/server/server.cpp (1 addition)

@@ -1266,6 +1266,7 @@ struct server_context {
             {"n_predict", slot.n_predict}, // Server configured n_predict
             {"model", params.model_alias},
             {"seed", slot.sparams.seed},
+            {"seed_cur", slot.smpl ? gpt_sampler_get_seed(slot.smpl) : 0},
             {"temperature", slot.sparams.temp},
             {"dynatemp_range", slot.sparams.dynatemp_range},
             {"dynatemp_exponent", slot.sparams.dynatemp_exponent},
include/llama.h (4 additions)

@@ -1127,6 +1127,10 @@ extern "C" {
                     int32_t   n_logit_bias,
               const llama_logit_bias * logit_bias);
 
+
+    // Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise
+    LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl);
+
     /// @details Sample and accept a token from the idx-th output of the last evaluation
     //
     // Shorthand for:
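A short sketch of what the new public entry point enables, using only API visible in this diff plus llama_sampler_init_dist and llama_sampler_free from the existing sampler API: retrieve the concrete seed of a randomly-seeded sampler so the run can be replayed.

    #include <cstdio>
    #include "llama.h"

    int main() {
        // Ask for a random seed; the dist sampler resolves the sentinel itself.
        llama_sampler * smpl = llama_sampler_init_dist(LLAMA_DEFAULT_SEED);

        // Record the seed actually in use for later reproduction.
        const uint32_t seed_cur = llama_sampler_get_seed(smpl);
        printf("effective seed: %u\n", seed_cur);

        // A sampler rebuilt with seed_cur produces the same sample stream.
        llama_sampler * replay = llama_sampler_init_dist(seed_cur);

        llama_sampler_free(replay);
        llama_sampler_free(smpl);
    }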
src/llama-sampling.cpp (74 additions, 17 deletions)

@@ -8,6 +8,7 @@
 #include <cstring>
 #include <ctime>
 #include <cfloat>
+#include <chrono>
 #include <cmath>
 #include <numeric>
 #include <random>
@@ -162,6 +163,19 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)
     cur_p->size = k;
 }
 
+static uint32_t get_rng_seed(uint32_t seed) {
+    if (seed == LLAMA_DEFAULT_SEED) {
+        // use system clock if std::random_device is not a true RNG
+        static bool is_rd_prng = std::random_device().entropy() == 0;
+        if (is_rd_prng) {
+            return (uint32_t) std::chrono::system_clock::now().time_since_epoch().count();
+        }
+        std::random_device rd;
+        return rd();
+    }
+    return seed;
+}
+
 // llama_sampler API
 
 const char * llama_sampler_name(const struct llama_sampler * smpl) {
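The entropy() probe above is the interesting part: some standard-library builds implement std::random_device as a deterministic PRNG, and such implementations conventionally report entropy() == 0, so the helper falls back to the system clock. A standalone restatement (my own) with the sentinel defined locally:

    #include <chrono>
    #include <cstdint>
    #include <cstdio>
    #include <random>

    static const uint32_t kDefaultSeed = 0xFFFFFFFF; // assumed LLAMA_DEFAULT_SEED

    static uint32_t resolve_seed(uint32_t seed) {
        if (seed != kDefaultSeed) {
            return seed; // explicit seeds pass through untouched
        }
        // entropy() == 0 signals a PRNG-backed random_device; the probe is
        // cached in a function-local static so it runs once per process.
        static const bool rd_is_prng = std::random_device().entropy() == 0;
        if (rd_is_prng) {
            return (uint32_t) std::chrono::system_clock::now().time_since_epoch().count();
        }
        std::random_device rd;
        return rd();
    }

    int main() {
        printf("explicit: %u\n", resolve_seed(42));           // always 42
        printf("random:   %u\n", resolve_seed(kDefaultSeed)); // varies per run
    }

One caveat carried over from the original comment: entropy() is only a heuristic, since some genuinely random implementations also report 0.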
@@ -387,6 +401,7 @@ struct llama_sampler * llama_sampler_init_greedy() {
 
 struct llama_sampler_dist {
     const uint32_t seed;
+    uint32_t seed_cur;
 
     std::mt19937 rng;
 };
@@ -416,7 +431,8 @@ static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sample
 
 static void llama_sampler_dist_reset(struct llama_sampler * smpl) {
     auto * ctx = (llama_sampler_dist *) smpl->ctx;
-    ctx->rng = std::mt19937(ctx->seed);
+    ctx->seed_cur = get_rng_seed(ctx->seed);
+    ctx->rng.seed(ctx->seed_cur);
 }
 
 static void llama_sampler_dist_free(struct llama_sampler * smpl) {
@@ -433,11 +449,13 @@ static struct llama_sampler_i llama_sampler_dist_i = {
 };
 
 struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
+    auto seed_cur = get_rng_seed(seed);
     return new llama_sampler {
         /* .iface = */ &llama_sampler_dist_i,
         /* .ctx = */ new llama_sampler_dist {
-            /* .seed = */ seed,
-            /* .rng = */ std::mt19937(seed),
+            /* .seed = */ seed,
+            /* .seed_cur = */ seed_cur,
+            /* .rng = */ std::mt19937(seed_cur),
         },
     };
 }
@@ -1032,6 +1050,7 @@ struct llama_sampler_mirostat {
     const int32_t n_vocab;
 
     const uint32_t seed;
+    uint32_t seed_cur;
 
     const float tau;
     const float eta;
@@ -1100,7 +1119,8 @@ static struct llama_sampler * llama_sampler_mirostat_clone(const struct llama_sa
 static void llama_sampler_mirostat_reset(struct llama_sampler * smpl) {
     auto * ctx = (llama_sampler_mirostat *) smpl->ctx;
     ctx->mu = 2.0f*ctx->tau;
-    ctx->rng = std::mt19937(ctx->seed);
+    ctx->seed_cur = get_rng_seed(ctx->seed);
+    ctx->rng.seed(ctx->seed_cur);
 }
 
 static void llama_sampler_mirostat_free(struct llama_sampler * smpl) {
@@ -1117,16 +1137,18 @@ static struct llama_sampler_i llama_sampler_mirostat_i = {
 };
 
 struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
+    auto seed_cur = get_rng_seed(seed);
     return new llama_sampler {
         /* .iface = */ &llama_sampler_mirostat_i,
         /* .ctx = */ new llama_sampler_mirostat {
-            /* .n_vocab = */ n_vocab,
-            /* .seed = */ seed,
-            /* .tau = */ tau,
-            /* .eta = */ eta,
-            /* .m = */ m,
-            /* .mu = */ 2.0f*tau,
-            /* .rng = */ std::mt19937(seed),
+            /* .n_vocab = */ n_vocab,
+            /* .seed = */ seed,
+            /* .seed_cur = */ seed_cur,
+            /* .tau = */ tau,
+            /* .eta = */ eta,
+            /* .m = */ m,
+            /* .mu = */ 2.0f*tau,
+            /* .rng = */ std::mt19937(seed_cur),
         },
     };
 }
@@ -1135,6 +1157,7 @@ struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t see
 
 struct llama_sampler_mirostat_v2 {
     const uint32_t seed;
+    uint32_t seed_cur;
 
     const float tau;
     const float eta;
@@ -1179,7 +1202,8 @@ static void llama_sampler_mirostat_v2_apply(struct llama_sampler * smpl, llama_t
 static void llama_sampler_mirostat_v2_reset(struct llama_sampler * smpl) {
     auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx;
     ctx->mu = 2.0f*ctx->tau;
-    ctx->rng = std::mt19937(ctx->seed);
+    ctx->seed_cur = get_rng_seed(ctx->seed);
+    ctx->rng.seed(ctx->seed_cur);
 }
 
 static struct llama_sampler * llama_sampler_mirostat_v2_clone(const struct llama_sampler * smpl) {
@@ -1212,14 +1236,16 @@ static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
 };
 
 struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
+    auto seed_cur = get_rng_seed(seed);
     return new llama_sampler {
         /* .iface = */ &llama_sampler_mirostat_v2_i,
         /* .ctx = */ new llama_sampler_mirostat_v2 {
-            /* .seed = */ seed,
-            /* .tau = */ tau,
-            /* .eta = */ eta,
-            /* .mu = */ 2.0f*tau,
-            /* .rng = */ std::mt19937(seed),
+            /* .seed = */ seed,
+            /* .seed_cur = */ seed_cur,
+            /* .tau = */ tau,
+            /* .eta = */ eta,
+            /* .mu = */ 2.0f*tau,
+            /* .rng = */ std::mt19937(seed_cur),
         },
     };
 }
@@ -1505,6 +1531,8 @@ struct llama_sampler * llama_sampler_init_penalties(
         ignore_eos = false;
     }
 
+    penalty_last_n = std::max(penalty_last_n, 0);
+
     return new llama_sampler {
         /* .iface = */ &llama_sampler_penalties_i,
         /* .ctx = */ new llama_sampler_penalties {
@@ -1568,6 +1596,7 @@ static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_to
         }
     }
 }
+
 static struct llama_sampler * llama_sampler_logit_bias_clone(const struct llama_sampler * smpl) {
     const auto * ctx = (const llama_sampler_logit_bias *) smpl->ctx;
     return llama_sampler_init_logit_bias(ctx->n_vocab, ctx->logit_bias.size(), ctx->logit_bias.data());
@@ -1599,3 +1628,31 @@ struct llama_sampler * llama_sampler_init_logit_bias(
         },
     };
 }
+
+// utils
+
+uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {
+    if (smpl->iface == &llama_sampler_dist_i) {
+        return ((const llama_sampler_dist *) smpl->ctx)->seed_cur;
+    }
+
+    if (smpl->iface == &llama_sampler_mirostat_i) {
+        return ((const llama_sampler_mirostat *) smpl->ctx)->seed_cur;
+    }
+
+    if (smpl->iface == &llama_sampler_mirostat_v2_i) {
+        return ((const llama_sampler_mirostat_v2 *) smpl->ctx)->seed_cur;
+    }
+
+    if (smpl->iface == &llama_sampler_chain_i) {
+        const auto * ctx = (const llama_sampler_chain *) smpl->ctx;
+        for (auto it = ctx->samplers.rbegin(); it != ctx->samplers.rend(); ++it) {
+            const uint32_t seed = llama_sampler_get_seed(*it);
+            if (seed != LLAMA_DEFAULT_SEED) {
+                return seed;
+            }
+        }
+    }
+
+    return LLAMA_DEFAULT_SEED;
+}
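The chain case is the one the examples actually hit: gpt_sampler wraps a sampler chain whose last link is typically the dist (or mirostat) sampler, and the reverse iteration above returns the seed of the last seeded link. A usage sketch against the public chain API (llama_sampler_chain_* and llama_sampler_init_top_k are from the existing llama.h sampler API; top-k is just an arbitrary unseeded link here):

    #include <cstdio>
    #include "llama.h"

    int main() {
        llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());

        llama_sampler_chain_add(chain, llama_sampler_init_top_k(40)); // carries no seed
        llama_sampler_chain_add(chain, llama_sampler_init_dist(1234)); // seeded link

        // Walks the chain back-to-front and reports 1234, the seed of the
        // last (and only) seed-carrying sampler.
        printf("chain seed: %u\n", llama_sampler_get_seed(chain));

        llama_sampler_free(chain); // also frees the samplers owned by the chain
    }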