Skip to content

Commit

Permalink
Merge pull request #316 from ngxson/xsn/llama_batch_remove_compat
Browse files Browse the repository at this point in the history
Xsn/llama batch remove compat
  • Loading branch information
Nexesenex authored Oct 15, 2024
2 parents 7eee341 + 4be7ecf commit bbb1ca9
Show file tree
Hide file tree
Showing 36 changed files with 980 additions and 845 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ variety of hardware - locally and in the cloud.
- Apple silicon is a first-class citizen - optimized via ARM NEON, Accelerate and Metal frameworks
- AVX, AVX2 and AVX512 support for x86 architectures
- 1.5-bit, 2-bit, 3-bit, 4-bit, 5-bit, 6-bit, and 8-bit integer quantization for faster inference and reduced memory use
- Custom CUDA kernels for running LLMs on NVIDIA GPUs (support for AMD GPUs via HIP)
- Custom CUDA kernels for running LLMs on NVIDIA GPUs (support for AMD GPUs via HIP and Moore Threads MTT GPUs via MUSA)
- Vulkan and SYCL backend support
- CPU+GPU hybrid inference to partially accelerate models larger than the total VRAM capacity

Expand Down Expand Up @@ -413,7 +413,7 @@ Please refer to [Build llama.cpp locally](./docs/build.md)
| [BLAS](./docs/build.md#blas-build) | All |
| [BLIS](./docs/backend/BLIS.md) | All |
| [SYCL](./docs/backend/SYCL.md) | Intel and Nvidia GPU |
| [MUSA](./docs/build.md#musa) | Moore Threads GPU |
| [MUSA](./docs/build.md#musa) | Moore Threads MTT GPU |
| [CUDA](./docs/build.md#cuda) | Nvidia GPU |
| [hipBLAS](./docs/build.md#hipblas) | AMD GPU |
| [Vulkan](./docs/build.md#vulkan) | GPU |
Expand Down
267 changes: 112 additions & 155 deletions common/arg.cpp

Large diffs are not rendered by default.

22 changes: 19 additions & 3 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

#include <algorithm>
#include <cinttypes>
#include <climits>
#include <cmath>
#include <codecvt>
#include <cstdarg>
Expand All @@ -23,10 +24,10 @@
#include <regex>
#include <sstream>
#include <string>
#include <thread>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include <thread>

#if defined(__APPLE__) && defined(__MACH__)
#include <sys/types.h>
Expand Down Expand Up @@ -400,6 +401,21 @@ std::string common_params_get_system_info(const common_params & params) {
// String utils
//

std::string string_format(const char * fmt, ...) {
va_list ap;
va_list ap2;
va_start(ap, fmt);
va_copy(ap2, ap);
int size = vsnprintf(NULL, 0, fmt, ap);
GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT
std::vector<char> buf(size + 1);
int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2);
GGML_ASSERT(size2 == size);
va_end(ap2);
va_end(ap);
return std::string(buf.data(), size);
}

std::vector<std::string> string_split(std::string input, char separator) {
std::vector<std::string> parts;
size_t separator_pos = input.find(separator);
Expand Down Expand Up @@ -939,7 +955,7 @@ struct common_init_result common_init_from_params(common_params & params) {
}

if (llama_model_has_encoder(model)) {
llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size(), 0, 0));
llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size()));
llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
if (decoder_start_token_id == -1) {
decoder_start_token_id = bos;
Expand All @@ -948,7 +964,7 @@ struct common_init_result common_init_from_params(common_params & params) {
tmp.push_back(decoder_start_token_id);
}
if (llama_model_has_decoder(model)) {
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch)));
}
llama_kv_cache_clear(lctx);
llama_synchronize(lctx);
Expand Down
20 changes: 16 additions & 4 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,6 @@ struct common_params {
std::string hostname = "127.0.0.1";
std::string public_path = ""; // NOLINT
std::string chat_template = ""; // NOLINT
std::string system_prompt = ""; // NOLINT
bool enable_chat_template = true;

std::vector<std::string> api_keys;
Expand Down Expand Up @@ -352,15 +351,28 @@ void common_init();

std::string common_params_get_system_info(const common_params & params);

bool parse_cpu_range(const std::string& range, bool(&boolmask)[GGML_MAX_N_THREADS]);
bool parse_cpu_mask(const std::string& mask, bool(&boolmask)[GGML_MAX_N_THREADS]);
void postprocess_cpu_params(cpu_params& cpuparams, const cpu_params* role_model = nullptr);
bool parse_cpu_range(const std::string & range, bool(&boolmask)[GGML_MAX_N_THREADS]);
bool parse_cpu_mask(const std::string & mask, bool(&boolmask)[GGML_MAX_N_THREADS]);
void postprocess_cpu_params(cpu_params & cpuparams, const cpu_params * role_model = nullptr);
bool set_process_priority(enum ggml_sched_priority prio);

//
// String utils
//

#ifdef __GNUC__
#ifdef __MINGW32__
#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__)))
#else
#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
#endif
#else
#define LLAMA_COMMON_ATTRIBUTE_FORMAT(...)
#endif

LLAMA_COMMON_ATTRIBUTE_FORMAT(1, 2)
std::string string_format(const char * fmt, ...);

std::vector<std::string> string_split(std::string input, char separator);

std::string string_strip(const std::string & str);
Expand Down
8 changes: 8 additions & 0 deletions docs/build.md
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,8 @@ The following compilation options are also available to tweak performance:
### MUSA
This provides GPU acceleration using the MUSA cores of your Moore Threads MTT GPU. Make sure to have the MUSA SDK installed. You can download it from here: [MUSA SDK](https://developer.mthreads.com/sdk/download/musa).
- Using `make`:
```bash
make GGML_MUSA=1
Expand All @@ -209,6 +211,12 @@ The following compilation options are also available to tweak performance:
cmake --build build --config Release
```
The environment variable [`MUSA_VISIBLE_DEVICES`](https://docs.mthreads.com/musa-sdk/musa-sdk-doc-online/programming_guide/Z%E9%99%84%E5%BD%95/) can be used to specify which GPU(s) will be used.
The environment variable `GGML_CUDA_ENABLE_UNIFIED_MEMORY=1` can be used to enable unified memory in Linux. This allows swapping to system RAM instead of crashing when the GPU VRAM is exhausted.
Most of the compilation options available for CUDA should also be available for MUSA, though they haven't been thoroughly tested yet.

### hipBLAS

This provides BLAS acceleration on HIP-supported AMD GPUs.
Expand Down
1 change: 0 additions & 1 deletion examples/batched-bench/batched-bench.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ int main(int argc, char ** argv) {
batch.n_seq_id + i,
batch.seq_id + i,
batch.logits + i,
0, 0, 0, // unused
};

const int ret = llama_decode(ctx, batch_view);
Expand Down
2 changes: 1 addition & 1 deletion examples/cvector-generator/cvector-generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ static bool cb_eval(struct ggml_tensor * t, bool ask, void * user_data) {

static bool get_hidden_layers(llama_context * ctx, std::vector<llama_token> & tokens) {
llama_kv_cache_clear(ctx);
if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size(), 0, 0))) {
if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return false;
}
Expand Down
2 changes: 1 addition & 1 deletion examples/eval-callback/eval-callback.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ static bool run(llama_context * ctx, const common_params & params) {

std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, add_bos);

if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size(), 0, 0))) {
if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) {
LOG_ERR("%s : failed to eval\n", __func__);
return false;
}
Expand Down
13 changes: 11 additions & 2 deletions examples/imatrix/imatrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,8 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
// clear the KV cache
llama_kv_cache_clear(ctx);

llama_batch batch = llama_batch_init(n_batch, 0, 1);

for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch;
const int batch_size = std::min(end - batch_start, n_batch);
Expand All @@ -508,9 +510,14 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
}

// TODO: use batch.logits to save computations instead of relying on logits_all == true
if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
common_batch_clear(batch);
for (int i = 0; i < batch_size; i++) {
common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true);
}

if (llama_decode(ctx, batch)) {
LOG_ERR("%s : failed to eval\n", __func__);
llama_batch_free(batch);
return false;
}

Expand All @@ -523,6 +530,8 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
}
}

llama_batch_free(batch);

const auto t_end = std::chrono::high_resolution_clock::now();

if (i == 0) {
Expand Down
18 changes: 9 additions & 9 deletions examples/infill/infill.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,11 +205,11 @@ int main(int argc, char ** argv) {
std::vector<llama_token> inp_pfx = common_tokenize(ctx, params.input_prefix, false);
std::vector<llama_token> inp_sfx = common_tokenize(ctx, params.input_suffix, false);

GGML_ASSERT(llama_token_prefix(model) >= 0);
GGML_ASSERT(llama_token_suffix(model) >= 0);
GGML_ASSERT(llama_token_fim_pre(model) >= 0);
GGML_ASSERT(llama_token_fim_suf(model) >= 0);

inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(model));
inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(model));
inp_pfx.insert(inp_pfx.begin(), llama_token_fim_pre(model));
inp_sfx.insert(inp_sfx.begin(), llama_token_fim_suf(model));

embd_inp = params.spm_infill ? inp_sfx : inp_pfx;
embd_end = params.spm_infill ? inp_pfx : inp_sfx;
Expand All @@ -218,7 +218,7 @@ int main(int argc, char ** argv) {
}
embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());

const llama_token middle_token = llama_token_middle(model);
const llama_token middle_token = llama_token_fim_mid(model);
if (middle_token >= 0) {
embd_inp.push_back(middle_token);
}
Expand Down Expand Up @@ -376,7 +376,7 @@ int main(int argc, char ** argv) {
n_past, n_left, n_ctx, params.n_keep, n_discard);

llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1);
llama_kv_cache_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
llama_kv_cache_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past + 1, -n_discard);

n_past -= n_discard;

Expand All @@ -396,7 +396,7 @@ int main(int argc, char ** argv) {

LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());

if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0))) {
if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval))) {
LOG_ERR("%s : failed to eval\n", __func__);
return 1;
}
Expand Down Expand Up @@ -508,8 +508,8 @@ int main(int argc, char ** argv) {
std::vector<llama_token> inp_pfx = common_tokenize(ctx, params.input_prefix, false);
std::vector<llama_token> inp_sfx = common_tokenize(ctx, params.input_suffix, false);

inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(model));
inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(model));
inp_pfx.insert(inp_pfx.begin(), llama_token_fim_pre(model));
inp_sfx.insert(inp_sfx.begin(), llama_token_fim_suf(model));

embd_inp = params.spm_infill ? inp_sfx : inp_pfx;
embd_end = params.spm_infill ? inp_pfx : inp_sfx;
Expand Down
16 changes: 8 additions & 8 deletions examples/llama-bench/llama-bench.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1428,7 +1428,7 @@ struct sql_printer : public printer {
}
};

static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_batch, int n_threads) {
static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_threads) {
llama_set_n_threads(ctx, n_threads, n_threads);

const llama_model * model = llama_get_model(ctx);
Expand All @@ -1444,14 +1444,14 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_bat
for (int i = 1; i < n_tokens; i++) {
tokens[i] = std::rand() % n_vocab;
}
llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens, n_past + n_processed, 0));
llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens));
n_processed += n_tokens;
}

llama_synchronize(ctx);
}

static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) {
static void test_gen(llama_context * ctx, int n_gen, int n_threads) {
llama_set_n_threads(ctx, n_threads, n_threads);

const llama_model * model = llama_get_model(ctx);
Expand All @@ -1460,7 +1460,7 @@ static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads)
llama_token token = llama_add_bos_token(model) ? llama_token_bos(model) : std::rand() % n_vocab;

for (int i = 0; i < n_gen; i++) {
llama_decode(ctx, llama_batch_get_one(&token, 1, n_past + i, 0));
llama_decode(ctx, llama_batch_get_one(&token, 1));
llama_synchronize(ctx);
token = std::rand() % n_vocab;
}
Expand Down Expand Up @@ -1596,13 +1596,13 @@ int main(int argc, char ** argv) {
fprintf(stderr, "llama-bench: benchmark %d/%ld: warmup prompt run\n", params_idx, params_count);
}
//test_prompt(ctx, std::min(t.n_batch, std::min(t.n_prompt, 32)), 0, t.n_batch, t.n_threads);
test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads);
test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
}
if (t.n_gen > 0) {
if (params.progress) {
fprintf(stderr, "llama-bench: benchmark %d/%ld: warmup generation run\n", params_idx, params_count);
}
test_gen(ctx, 1, 0, t.n_threads);
test_gen(ctx, 1, t.n_threads);
}

for (int i = 0; i < params.reps; i++) {
Expand All @@ -1614,13 +1614,13 @@ int main(int argc, char ** argv) {
if (params.progress) {
fprintf(stderr, "llama-bench: benchmark %d/%ld: prompt run %d/%d\n", params_idx, params_count, i + 1, params.reps);
}
test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads);
test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads);
}
if (t.n_gen > 0) {
if (params.progress) {
fprintf(stderr, "llama-bench: benchmark %d/%ld: generation run %d/%d\n", params_idx, params_count, i + 1, params.reps);
}
test_gen(ctx, t.n_gen, t.n_prompt, t.n_threads);
test_gen(ctx, t.n_gen, t.n_threads);
}

uint64_t t_ns = get_time_ns() - t_start;
Expand Down
3 changes: 0 additions & 3 deletions examples/llama.android/llama/src/main/cpp/llama-android.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -283,9 +283,6 @@ Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens,
nullptr,
nullptr,
nullptr,
0,
0,
0,
};

if (embd) {
Expand Down
2 changes: 1 addition & 1 deletion examples/llava/llava-cli.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_toke
if (n_eval > n_batch) {
n_eval = n_batch;
}
if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval, *n_past, 0))) {
if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval))) {
LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
return false;
}
Expand Down
38 changes: 36 additions & 2 deletions examples/llava/llava.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,39 @@ bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, co
return true;
}

struct llava_embd_batch {
std::vector<llama_pos> pos;
std::vector<int32_t> n_seq_id;
std::vector<llama_seq_id> seq_id_0;
std::vector<llama_seq_id *> seq_ids;
std::vector<int8_t> logits;
llama_batch batch;
llava_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) {
pos .resize(n_tokens);
n_seq_id.resize(n_tokens);
seq_ids .resize(n_tokens + 1);
logits .resize(n_tokens);
seq_id_0.resize(1);
seq_id_0[0] = seq_id;
seq_ids [n_tokens] = nullptr;
batch = {
/*n_tokens =*/ n_tokens,
/*tokens =*/ nullptr,
/*embd =*/ embd,
/*pos =*/ pos.data(),
/*n_seq_id =*/ n_seq_id.data(),
/*seq_id =*/ seq_ids.data(),
/*logits =*/ logits.data(),
};
for (int i = 0; i < n_tokens; i++) {
batch.pos [i] = pos_0 + i;
batch.n_seq_id[i] = 1;
batch.seq_id [i] = seq_id_0.data();
batch.logits [i] = false;
}
}
};

bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed, int n_batch, int * n_past) {
int n_embd = llama_n_embd(llama_get_model(ctx_llama));

Expand All @@ -409,8 +442,9 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_
if (n_eval > n_batch) {
n_eval = n_batch;
}
llama_batch batch = {int32_t(n_eval), nullptr, (image_embed->embed+i*n_embd), nullptr, nullptr, nullptr, nullptr, *n_past, 1, 0, };
if (llama_decode(ctx_llama, batch)) {
float * embd = image_embed->embed+i*n_embd;
llava_embd_batch llava_batch = llava_embd_batch(embd, n_eval, *n_past, 0);
if (llama_decode(ctx_llama, llava_batch.batch)) {
LOG_ERR("%s : failed to eval\n", __func__);
return false;
}
Expand Down
Loading

0 comments on commit bbb1ca9

Please sign in to comment.