Skip to content

Commit

Permalink
whisper : use backend registry
Browse files Browse the repository at this point in the history
  • Loading branch information
ggerganov committed Nov 20, 2024
1 parent c800966 commit fbae8dc
Showing 1 changed file with 60 additions and 123 deletions.
183 changes: 60 additions & 123 deletions src/whisper.cpp
Original file line number Diff line number Diff line change
@@ -1,43 +1,19 @@
#include "whisper.h"

#ifdef WHISPER_USE_COREML
#include "coreml/whisper-encoder.h"
#endif

#include "ggml-cpu.h"

#ifdef GGML_USE_METAL
#include "ggml-metal.h"
#endif

#ifdef GGML_USE_CUDA
#include "ggml-cuda.h"
#endif

#ifdef GGML_USE_SYCL
#include "ggml-sycl.h"
#endif

#ifdef GGML_USE_VULKAN
#include "ggml-vulkan.h"
#endif
#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"

#ifdef GGML_USE_BLAS
#include "ggml-blas.h"
#ifdef WHISPER_USE_COREML
#include "coreml/whisper-encoder.h"
#endif

#ifdef WHISPER_USE_OPENVINO
#include "openvino/whisper-openvino-encoder.h"
#endif

#ifdef GGML_USE_CANN
#include "ggml-cann.h"
#endif

#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"

#include <atomic>
#include <algorithm>
#include <cassert>
Expand Down Expand Up @@ -195,14 +171,13 @@ static bool ggml_graph_compute_helper(

for (int i = 0; i < ggml_backend_sched_get_n_backends(sched); ++i) {
ggml_backend_t backend = ggml_backend_sched_get_backend(sched, i);
if (ggml_backend_is_cpu(backend)) {
ggml_backend_cpu_set_n_threads(backend, n_threads);
}
#ifdef GGML_USE_BLAS
if (ggml_backend_is_blas(backend)) {
ggml_backend_blas_set_n_threads(backend, n_threads);
ggml_backend_dev_t dev = ggml_backend_get_device(backend);
ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;

auto * fn_set_n_threads = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
if (fn_set_n_threads) {
fn_set_n_threads(backend, n_threads);
}
#endif
}

bool t = ggml_backend_sched_graph_compute(sched, graph) == GGML_STATUS_SUCCESS;
Expand Down Expand Up @@ -1260,61 +1235,24 @@ static ggml_backend_t whisper_backend_init_gpu(const whisper_context_params & pa

ggml_log_set(g_state.log_callback, g_state.log_callback_user_data);

#ifdef GGML_USE_CUDA
if (params.use_gpu) {
WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__);
result = ggml_backend_cuda_init(params.gpu_device);
if (!result) {
WHISPER_LOG_ERROR("%s: ggml_backend_cuda_init() failed\n", __func__);
}
}
#endif

#ifdef GGML_USE_METAL
if (params.use_gpu) {
WHISPER_LOG_INFO("%s: using Metal backend\n", __func__);
result = ggml_backend_metal_init();
if (!result) {
WHISPER_LOG_ERROR("%s: ggml_backend_metal_init() failed\n", __func__);
} else if (!ggml_backend_metal_supports_family(result, 7)) {
WHISPER_LOG_ERROR("%s: Metal GPU does not support family 7 - falling back to CPU\n", __func__);
ggml_backend_free(result);
result = NULL;
}
}
#endif

#ifdef GGML_USE_SYCL
if (params.use_gpu) {
WHISPER_LOG_INFO("%s: using SYCL backend\n", __func__);
result = ggml_backend_sycl_init(params.gpu_device);
if (!result) {
WHISPER_LOG_ERROR("%s: ggml_backend_sycl_init() failed\n", __func__);
}
}
#endif

#ifdef GGML_USE_VULKAN
if (params.use_gpu) {
WHISPER_LOG_INFO("%s: using Vulkan backend\n", __func__);
result = ggml_backend_vk_init(params.gpu_device);
if (!result) {
WHISPER_LOG_ERROR("%s: ggml_backend_vk_init() failed\n", __func__);
}
}
#endif

#ifdef GGML_USE_CANN
if (params.use_gpu) {
WHISPER_LOG_INFO("%s: using CANN backend\n", __func__);
result = ggml_backend_cann_init(params.gpu_device);
if (!result) {
WHISPER_LOG_ERROR("%s: ggml_backend_cann_init() failed\n", __func__);
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
ggml_backend_dev_t dev = ggml_backend_dev_get(i);
switch (ggml_backend_dev_type(dev)) {
case GGML_BACKEND_DEVICE_TYPE_CPU:
case GGML_BACKEND_DEVICE_TYPE_ACCEL:
// skip CPU backends
break;
case GGML_BACKEND_DEVICE_TYPE_GPU:
WHISPER_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(dev));
result = ggml_backend_dev_init(dev, nullptr);
if (!result) {
WHISPER_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
}
break;
}
}
}
#endif

GGML_UNUSED(params);

return result;
}
Expand All @@ -1328,17 +1266,19 @@ static std::vector<ggml_backend_t> whisper_backend_init(const whisper_context_pa
result.push_back(backend_gpu);
}

#ifdef GGML_USE_BLAS
{
WHISPER_LOG_INFO("%s: using BLAS backend\n", __func__);
ggml_backend_t backend_blas = ggml_backend_blas_init();
if (!backend_blas) {
WHISPER_LOG_ERROR("%s: ggml_backend_blas_init() failed\n", __func__);
} else {
result.push_back(backend_blas);
// ACCEL backends
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
ggml_backend_dev_t dev = ggml_backend_dev_get(i);
if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL) {
WHISPER_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(dev));
ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
if (!backend) {
WHISPER_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
continue;
}
result.push_back(backend);
}
}
#endif

GGML_UNUSED(params);

Expand All @@ -1348,33 +1288,26 @@ static std::vector<ggml_backend_t> whisper_backend_init(const whisper_context_pa
}

static ggml_backend_buffer_type_t whisper_default_buffer_type(const whisper_context_params & params) {
ggml_backend_buffer_type_t result = nullptr;

params.use_gpu || (result = ggml_backend_cpu_buffer_type());

#ifdef GGML_USE_CUDA
result || (result = ggml_backend_cuda_buffer_type(params.gpu_device));
#endif

#ifdef GGML_USE_METAL
result || (result = ggml_backend_metal_buffer_type());
#endif

#ifdef GGML_USE_SYCL
result || (result = ggml_backend_sycl_buffer_type(params.gpu_device));
#endif

#ifdef GGML_USE_VULKAN
result || (result = ggml_backend_vk_buffer_type(params.gpu_device));
#endif
if (!params.use_gpu) {
return ggml_backend_cpu_buffer_type();
}

#ifdef GGML_USE_CANN
result || (result == ggml_backend_cann_buffer_type(params.gpu_device));
#endif
// if we have a GPU device - use it
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
ggml_backend_dev_t dev = ggml_backend_dev_get(i);
switch (ggml_backend_dev_type(dev)) {
case GGML_BACKEND_DEVICE_TYPE_CPU:
case GGML_BACKEND_DEVICE_TYPE_ACCEL:
// skip CPU backends
break;

result || (result = ggml_backend_cpu_buffer_type());
case GGML_BACKEND_DEVICE_TYPE_GPU:
WHISPER_LOG_INFO("%s: using device %s (%s)\n", __func__, ggml_backend_dev_name(dev), ggml_backend_dev_description(dev));
return ggml_backend_dev_buffer_type(dev);
}
}

return result;
return ggml_backend_cpu_buffer_type();
}

// load the model from a ggml file
Expand Down Expand Up @@ -3668,8 +3601,7 @@ struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_
WHISPER_LOG_INFO("%s: flash attn = %d\n", __func__, params.flash_attn);
WHISPER_LOG_INFO("%s: gpu_device = %d\n", __func__, params.gpu_device);
WHISPER_LOG_INFO("%s: dtw = %d\n", __func__, params.dtw_token_timestamps);

// TODO: temporary call to force backend registry initialization
WHISPER_LOG_INFO("%s: devices = %zu\n", __func__, ggml_backend_dev_count());
WHISPER_LOG_INFO("%s: backends = %zu\n", __func__, ggml_backend_reg_count());

whisper_context * ctx = new whisper_context;
Expand Down Expand Up @@ -7427,6 +7359,11 @@ static void whisper_log_internal(ggml_log_level level, const char * format, ...)
static void whisper_log_callback_default(ggml_log_level level, const char * text, void * user_data) {
(void) level;
(void) user_data;
#ifndef WHISPER_DEBUG
if (level == GGML_LOG_LEVEL_DEBUG) {
return;
}
#endif
fputs(text, stderr);
fflush(stderr);
}

0 comments on commit fbae8dc

Please sign in to comment.