[LLM Runtime] Add jblas split weight interface and support jblas models (#639)

* [LLM Runtime] Add jblas split weight interface and support jblas models

Signed-off-by: Clark Chin <xi2.chen@intel.com>
ClarkChin08 authored Nov 25, 2023
1 parent a87ab22 commit 22ceda4
Showing 6 changed files with 211 additions and 39 deletions.
@@ -181,7 +181,7 @@ class WeightS8ScaleFp32 {
     auto Tscales = utils::amalloc<float>(ssize);
     auto Tzps = utils::amalloc<int8_t>(ptr->mIsAsym ? ssize : 0);
     quantizeWeight(N, K, B, ldb, ptr->mBlockSize, tmpq, Tscales, Tzps);
-    packQWeight(N, K, tmpq, ldb, Tscales, Tzps, stor);
+    packQWeight(N, K, tmpq, N, Tscales, Tzps, stor);
     utils::afree(tmpq);
     utils::afree(Tscales);
     utils::afree(Tzps);
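The one-line fix above corrects the leading dimension passed to `packQWeight`: `tmpq` is the freshly quantized dense K x N buffer produced by `quantizeWeight`, so its row stride is `N` regardless of the source matrix's `ldb`. A minimal sketch of the layout assumption (toy numbers, not code from this repo):

```cpp
#include <cstdio>

// Row-major indexing with an explicit leading dimension: element (i, j) of a
// buffer with row stride `ld` lives at buf[i * ld + j]. The temporary int8
// buffer is dense, so its leading dimension is N; indexing it with the
// source's ldb reads the wrong element whenever ldb != N.
int main() {
  const int N = 3, ldb = 5;                      // source stride padded to 5
  signed char tmpq[6] = {1, 2, 3, 4, 5, 6};      // dense 2 x 3 buffer, ld == N
  std::printf("row 1, col 0 with ld=N:   %d\n", tmpq[1 * N + 0]);    // 4 (correct)
  std::printf("row 1, col 0 with ld=ldb: %d\n", tmpq[1 * ldb + 0]);  // 6 (wrong element)
  return 0;
}
```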
@@ -12,7 +12,9 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 #include "jblas_common.hpp"
+#include "jblas/jit_blas_weight_compression.h"
 using namespace jblas;
+using namespace ne_jblas;
 
 void jblas_init() {
   GetCPUDevice();
@@ -34,3 +36,122 @@ int jblas_set_threads(int _nth) {
   jblas::utils::parallel::CpuDevice::getInstance()->setThreads(_nth);
   return jblas::utils::parallel::CpuDevice::getInstance()->getThreads();
 }
+
+template <template <class, JBLAS_ISA> class Wei_T, class GC_T>
+void jblas_unpackweight(void* wptr, int n, int k, float* fp32data, int ld) {
+  GetCPUDevice();
+  using ProB_AVX512 = Wei_T<GC_T, JblasAVX512F>;
+  using Prob_AVX2 = Wei_T<GC_T, JblasAVX2>;
+  if (_cd->AVX512F()) {
+    static ProB_AVX512 prob;
+    prob.unpackWeight(n, k, wptr, fp32data, ld);
+    return;
+  }
+  if (_cd->AVX2()) {
+    static Prob_AVX2 prob;
+    prob.unpackWeight(n, k, wptr, fp32data, ld);
+    return;
+  }
+}
+
+void jblas_unpackweight_fp32(void* wptr, int n, int k, float* fp32data, int ld) {
+  auto wtmp = prologue::weight_comp::gemm_kblcok::PackedWeightParser::deserialBuffer(wptr);
+  if (wtmp != nullptr) {
+    if (wtmp->mPrologueID == int(WeightCompType::WeightS4ClipScaleFp32)) {
+      if (wtmp->mCoreType == GcCompFp32::TYPE) {
+        jblas_unpackweight<WeiS4ClipFp32, GcCompFp32>(wtmp, n, k, fp32data, ld);
+      }
+      if (wtmp->mCoreType == GcCompInt8::TYPE || wtmp->mCoreType == GcCompInt8KBlock::TYPE) {
+        jblas_unpackweight<WeiS4ClipFp32, GcCompInt8>(wtmp, n, k, fp32data, ld);
+      }
+    } else if (wtmp->mPrologueID == int(WeightCompType::WeightS8ScaleFp32)) {
+      if (wtmp->mCoreType == GcCompFp32::TYPE) {
+        jblas_unpackweight<WeiS8Fp32, GcCompFp32>(wtmp, n, k, fp32data, ld);
+      }
+      if (wtmp->mCoreType == GcCompInt8::TYPE || wtmp->mCoreType == GcCompInt8KBlock::TYPE) {
+        jblas_unpackweight<WeiS8Fp32, GcCompInt8>(wtmp, n, k, fp32data, ld);
+      }
+    } else if (wtmp->mPrologueID == int(WeightCompType::WeightS8ScaleFp32PerChannelN)) {
+      if (wtmp->mCoreType == GcCompFp32::TYPE) {
+        jblas_unpackweight<WeiS8Fp32PerN, GcCompFp32>(wtmp, n, k, fp32data, ld);
+      }
+      if (wtmp->mCoreType == GcCompInt8::TYPE || wtmp->mCoreType == GcCompInt8KBlock::TYPE) {
+        jblas_unpackweight<WeiS8Fp32PerN, GcCompInt8>(wtmp, n, k, fp32data, ld);
+      }
+    } else if (wtmp->mPrologueID == int(WeightCompType::WeightS4ClipScaleFp32PerChannelN)) {
+      if (wtmp->mCoreType == GcCompFp32::TYPE) {
+        jblas_unpackweight<WeiS4ClipFp32PerN, GcCompFp32>(wtmp, n, k, fp32data, ld);
+      }
+      if (wtmp->mCoreType == GcCompInt8::TYPE || wtmp->mCoreType == GcCompInt8KBlock::TYPE) {
+        jblas_unpackweight<WeiS4ClipFp32PerN, GcCompInt8>(wtmp, n, k, fp32data, ld);
+      }
+    }
+  }
+  safe_delete(wtmp);
+}
+
+template <template <class, JBLAS_ISA> class Wei_T, class GC_T>
+void jblas_packweight(const float* fp32data, void* dstptr, int n, int k, int ld, void* srcptr) {
+  GetCPUDevice();
+  using ProB_AVX512 = Wei_T<GC_T, JblasAVX512F>;
+  using Prob_AVX2 = Wei_T<GC_T, JblasAVX2>;
+  using ST = typename Prob_AVX2::StorageWeight;
+  static Prob_AVX2 prob;
+  ST tmp(gemm::GemmCoreType::Undef);
+  auto src = (ST*)srcptr;
+  if constexpr (std::is_same_v<Prob_AVX2, WeiS4ClipFp32<GC_T, JblasAVX2>> ||
+                std::is_same_v<Prob_AVX2, WeiS8Fp32<GC_T, JblasAVX2>>) {
+    tmp = prob.createStorage(n, k, src->mBlockSize, src->mIsAsym);
+  }
+  if constexpr (std::is_same_v<Prob_AVX2, WeiS4ClipFp32PerN<GC_T, JblasAVX2>> ||
+                std::is_same_v<Prob_AVX2, WeiS8Fp32PerN<GC_T, JblasAVX2>>) {
+    tmp = prob.createStorage(n, k, src->mIsAsym);
+  }
+  tmp.assign((int8_t*)dstptr);
+  if (_cd->AVX512F()) {
+    static ProB_AVX512 prob;
+    prob.packWeight(n, k, fp32data, ld, &tmp);
+    return;
+  }
+  if (_cd->AVX2()) {
+    prob.packWeight(n, k, fp32data, ld, &tmp);
+    return;
+  }
+}
+
+void jblas_packweight_copyattr(const float* f32ptr, void* dstpr, int n, int k, int ld, void* srcptr) {
+  auto wtmp = prologue::weight_comp::gemm_kblcok::PackedWeightParser::deserialBuffer(srcptr);
+  if (wtmp != nullptr) {
+    if (wtmp->mPrologueID == int(WeightCompType::WeightS4ClipScaleFp32)) {
+      if (wtmp->mCoreType == GcCompFp32::TYPE) {
+        jblas_packweight<WeiS4ClipFp32, GcCompFp32>(f32ptr, dstpr, n, k, ld, wtmp);
+      }
+      if (wtmp->mCoreType == GcCompInt8KBlock::TYPE) {
+        jblas_packweight<WeiS4ClipFp32, GcCompInt8KBlock>(f32ptr, dstpr, n, k, ld, wtmp);
+      }
+    } else if (wtmp->mPrologueID == int(WeightCompType::WeightS8ScaleFp32)) {
+      if (wtmp->mCoreType == GcCompFp32::TYPE) {
+        jblas_packweight<WeiS8Fp32, GcCompFp32>(f32ptr, dstpr, n, k, ld, wtmp);
+      }
+      if (wtmp->mCoreType == GcCompInt8KBlock::TYPE) {
+        jblas_packweight<WeiS8Fp32, GcCompInt8KBlock>(f32ptr, dstpr, n, k, ld, wtmp);
+      }
+    } else if (wtmp->mPrologueID == int(WeightCompType::WeightS8ScaleFp32PerChannelN)) {
+      if (wtmp->mCoreType == GcCompFp32::TYPE) {
+        jblas_packweight<WeiS8Fp32PerN, GcCompFp32>(f32ptr, dstpr, n, k, ld, wtmp);
+      }
+      if (wtmp->mCoreType == GcCompInt8::TYPE) {
+        jblas_packweight<WeiS8Fp32PerN, GcCompInt8>(f32ptr, dstpr, n, k, ld, wtmp);
+      }
+
+    } else if (wtmp->mPrologueID == int(WeightCompType::WeightS4ClipScaleFp32PerChannelN)) {
+      if (wtmp->mCoreType == GcCompFp32::TYPE) {
+        jblas_packweight<WeiS4ClipFp32PerN, GcCompFp32>(f32ptr, dstpr, n, k, ld, wtmp);
+      }
+      if (wtmp->mCoreType == GcCompInt8::TYPE) {
+        jblas_packweight<WeiS4ClipFp32PerN, GcCompInt8>(f32ptr, dstpr, n, k, ld, wtmp);
+      }
+    }
+  }
+  safe_delete(wtmp);
+}
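Together, `jblas_unpackweight_fp32` and `jblas_packweight_copyattr` let a caller slice a jblas-packed weight without touching its quantization metadata: decompress to fp32, pick a tile, and recompress with the attributes of the source. A hedged usage sketch of that pattern (buffer ownership and sizing simplified; the real call site is `jblas_split_weight` in the model loader below):

```cpp
#include <vector>
#include "core/layers/jblas_common.hpp"  // declares the two entry points (as in this commit)

// Sketch: split a jblas-packed weight into per-rank slices along n, mirroring
// the TP_1D_ROW call added to the loader later in this commit (n_rank = rank,
// k_rank = 0). The fp32 buffer holds k rows of n floats, so ld == n.
void split_along_n(void* packed_src, void* packed_dst, int n, int k, int world_size, int rank) {
  std::vector<float> fp32(static_cast<size_t>(n) * k);
  jblas_unpackweight_fp32(packed_src, n, k, fp32.data(), /*ld=*/n);
  const int dst_n = n / world_size;
  // a slice along n starts rank * dst_n floats into each row; ld stays n
  jblas_packweight_copyattr(fp32.data() + static_cast<size_t>(rank) * dst_n, packed_dst,
                            dst_n, k, /*ld=*/n, packed_src);
}
```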
@@ -58,6 +58,9 @@ void jblas_fusion_FFN_Add_GeLu_f32f32_forward(float* activation, void* w1ptr, vo
                                               float* tmp1, float* output, int seq, int fin, int fmid, int fout,
                                               bool boardcast_bias, void* workspace);
 
+void jblas_unpackweight_fp32(void* wptr, int n, int k, float* fp32data, int ld);
+// packweight to dstptr, copy weight attributes from srcptr
+void jblas_packweight_copyattr(const float* f32ptr, void* dstpr, int n, int k, int ld, void* srcptr);
 #ifdef __cplusplus
 }
 #endif
@@ -230,17 +230,18 @@ static bool gptj_model_eval_internal(model_context& lctx, const model_input* inp
                                      n_past * ne_element_size(kv_self.k)));
       } else {
         // batch K
-        Kcur_bs[i] = ne_permute(ctx0,
-                                ne_view_4d(ctx0, Kcur, head_size, n_head, N, 1, ne_element_size(Kcur) * head_size,
-                                           ne_element_size(Kcur) * n_embd, ne_element_size(Kcur) * n_embd * N,
-                                           i * ne_element_size(Kcur) * n_embd * N),
-                                0, 2, 1, 3);
-        k_bs[i] =
-            ne_view_4d(ctx0, kv_self.k, head_size, N, n_head, 1, ne_element_size(kv_self.k) * head_size,
-                       ne_element_size(kv_self.k) * head_size * n_ctx, ne_element_size(kv_self.k) * n_embd * n_ctx,
-                       ((il * n_ctx) * ne_element_size(kv_self.k) * n_embd * kv_n_ctx_block +
-                        block_idx * n_ctx * n_embd * ne_element_size(kv_self.k) +
-                        head_size * n_past * ne_element_size(kv_self.k)));
+        Kcur_bs[i] = ne_permute(
+            ctx0,
+            ne_view_4d(ctx0, Kcur, head_size, n_head, N, 1, ne_element_size(Kcur) * head_size,
+                       ne_element_size(Kcur) * head_size * n_head, ne_element_size(Kcur) * head_size * n_head * N,
+                       i * ne_element_size(Kcur) * head_size * n_head * N),
+            0, 2, 1, 3);
+        k_bs[i] = ne_view_4d(ctx0, kv_self.k, head_size, N, n_head, 1, ne_element_size(kv_self.k) * head_size,
+                             ne_element_size(kv_self.k) * head_size * n_ctx,
+                             ne_element_size(kv_self.k) * head_size * n_head * n_ctx,
+                             ((il * n_ctx) * ne_element_size(kv_self.k) * head_size * n_head * kv_n_ctx_block +
+                              block_idx * n_ctx * head_size * n_head * ne_element_size(kv_self.k) +
+                              head_size * n_past * ne_element_size(kv_self.k)));
 
         // batch V
         Vcur_bs[i] = ne_permute(
@@ -432,15 +433,15 @@ static bool gptj_model_eval_internal(model_context& lctx, const model_input* inp
 
       struct ne_tensor* FFN_out = ne_mul_mat(ctx0, model.layers[il].ffn[2], cur);
       ne_set_name(FFN_out, "FFN_out");
 
+#ifdef NE_TP_MODEL
+      // if tp model then all reduce as the weight has been split
+      if (enable_tp) {
+        FFN_out = ne_all_reduce(ctx0, FFN_out);
+      }
+#endif
+      // NOTICE: when TP, only the master node adds this bias
       cur = ne_add(ctx0, ne_repeat(ctx0, model.layers[il].ffn[3], FFN_out), FFN_out);
     }
-#ifdef NE_TP_MODEL
-    // if tp model then all reduce as the weight has been split
-    if (enable_tp) {
-      cur = ne_all_reduce(ctx0, cur);
-    }
-#endif
     cur = ne_add(ctx0, cur, inpFF);
     // if (il == 20) {
    //   cur = ne_dump_tensor(ctx0, cur);
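Both GPT-J hunks are tensor-parallel fixes. The batched K view now derives its strides from `head_size * n_head` rather than `n_embd`: when attention weights are split by head across ranks, the local tensor holds only the local heads, so `n_embd` (the full model width) overshoots the local row width. And the FFN all-reduce moves before the bias add, pairing with the loader change below that zero-fills `fc_out.bias` on non-master ranks. A toy stride check (illustrative numbers, not taken from any real model config):

```cpp
#include <cstdio>

// With 2-way TP, suppose each rank keeps n_head = 2 of 4 heads of size 4;
// the full width n_embd = 16 no longer equals the local row width.
int main() {
  const int head_size = 4, n_head = 2 /* local */, n_embd = 16;
  const int token = 1;  // second token in the local Kcur buffer
  std::printf("offset with head_size*n_head: %d (stays inside the local slice)\n",
              token * head_size * n_head);   // 8
  std::printf("offset with n_embd:           %d (strides past the local data)\n",
              token * n_embd);               // 16
  return 0;
}
```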
@@ -352,13 +352,13 @@ static bool llama_model_eval_internal(model_context& lctx, const model_input* in
       cur = ne_silu(ctx0, cur);
       cur = ne_mul(ctx0, cur, tmp);
       cur = ne_mul_mat(ctx0, model.layers[il].ffn[1], cur);
-    }
 #ifdef NE_TP_MODEL
-    // ffn2 and ffn0 use split row, ffn1 use split column
-    if (enable_tp) {
-      cur = ne_all_reduce(ctx0, cur);
-    }
-#endif
+      // ffn2 and ffn0 use split row, ffn1 uses split column
+      if (enable_tp) {
+        cur = ne_all_reduce(ctx0, cur);
+      }
+#endif
+    }
 
     cur = ne_add(ctx0, cur, inpFF);
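In the LLaMA FFN the closing brace moves below the `#ifdef` block, so the all-reduce now runs inside the scope that produced the three `ne_mul_mat` calls, i.e. only on that code path. As the comment says, ffn0/ffn2 are row-split and ffn1 is column-split, so each rank ends the block with a partial sum of the FFN output and a single all-reduce recovers the full result. A toy demonstration of why the partial sums reduce to the exact answer (illustrative numbers only):

```cpp
#include <cstdio>

// Toy 1-D tensor-parallel y = W2 * (W1 * x): W1 split along the hidden dim,
// W2 split along the matching dim. Each rank holds a partial y; summing
// across ranks (the all-reduce) reproduces the single-device result.
int main() {
  const float x[2] = {1.0f, 2.0f};
  const float W1[4][2] = {{1, 0}, {0, 1}, {1, 1}, {2, 0}};  // hidden = 4
  const float W2[4] = {1, 2, 3, 4};                         // out = 1
  float y_full = 0, y_reduced = 0;
  for (int h = 0; h < 4; ++h) {
    y_full += W2[h] * (W1[h][0] * x[0] + W1[h][1] * x[1]);
  }
  for (int r = 0; r < 2; ++r) {          // two "ranks", hidden slices [0,2), [2,4)
    float y_part = 0;
    for (int h = 2 * r; h < 2 * r + 2; ++h) {
      y_part += W2[h] * (W1[h][0] * x[0] + W1[h][1] * x[1]);
    }
    y_reduced += y_part;                 // stands in for ne_all_reduce
  }
  std::printf("full=%g reduced=%g\n", y_full, y_reduced);  // identical: 22
  return 0;
}
```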
@@ -29,6 +29,7 @@
 #define NE_MEM_ALIGN 16
 #endif
 
+#include "core/layers/jblas_common.hpp"
 #include "core/ne_layers.h"
 #include "models/model_utils/util.h"
 #include "models/models.h"
@@ -76,7 +77,7 @@ struct model_load_tensor_shard {
   void calc_size() { size = model_calc_tensor_size(ne, type); }
 };
 
-enum model_split_type { SPLIT_NONE, SPLIT_BY_COLUMNS, SPLIT_BY_ROWS, TP_1D_ROW, TP_1D_COLUMN };
+enum model_split_type { SPLIT_NONE, SPLIT_BY_COLUMNS, SPLIT_BY_ROWS, TP_1D_ROW, TP_1D_COLUMN, TP_1D_ONLY_MASTER };
 
 struct model_load_tensor {
   std::vector<model_load_tensor_shard> shards;
@@ -152,6 +153,9 @@ struct model_load_tensor {
           name.find(".feed_forward.w2.weight") != std::string::npos) {
         split_type = TP_1D_COLUMN;
       }
+      if (name.find(".mlp.fc_out.bias") != std::string::npos) {
+        split_type = TP_1D_ONLY_MASTER;
+      }
     }
 #endif
   }
@@ -191,6 +195,9 @@ struct model_load_tensor {
         ne = {first_shard.ne[0] / world_size, first_shard.ne[1]};
       }
       break;
+      case TP_1D_ONLY_MASTER:
+        ne = first_shard.ne;
+        break;
 #endif
     }
   }
@@ -614,6 +621,18 @@ struct model_model_loader {
     }
   }
 
+  void jblas_split_weight(void** src, void** dst, size_t src_n, size_t src_k, size_t dst_n, size_t dst_k,
+                          size_t n_rank, size_t k_rank) {
+    auto src_fp32 = (float*)malloc(src_n * src_k * sizeof(float));
+    if (src_fp32 == nullptr) {
+      assert(0);
+    }
+    jblas_unpackweight_fp32(*src, src_n, src_k, src_fp32, src_n);
+    // layout will be K * N in the buffer
+    auto dst_fp32 = src_fp32 + k_rank * dst_k * src_n + n_rank * dst_n;
+    jblas_packweight_copyattr(dst_fp32, *dst, dst_n, dst_k, src_n, *src);
+    free(src_fp32);
+  }
+
   void load_data_for(model_load_tensor& lt) {
     if (use_mmap) {
       MODEL_ASSERT(lt.shards.size() == 1);
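`jblas_split_weight` is the new split interface: decompress the whole tensor to fp32 (leading dimension `src_n`), step to the tile owned by this rank, and recompress it with the source's attributes. The pointer arithmetic follows the K x N layout noted in the comment: a split along k advances `k_rank * dst_k` whole rows of `src_n` floats, while a split along n advances `n_rank * dst_n` within a row (the tile then being strided, which works because `ld` stays `src_n`). A hedged worked example of the offsets with toy sizes:

```cpp
#include <cstdio>

// Toy offsets only; in the real code the fp32 buffer comes from
// jblas_unpackweight_fp32. Full weight: src_k = 4 rows of src_n = 6 floats,
// split across world_size = 2 ranks.
int main() {
  const size_t src_n = 6, src_k = 4, world_size = 2;
  // split along n (the loader's TP_1D_ROW call: n_rank = rank, k_rank = 0)
  const size_t dst_n = src_n / world_size, n_rank = 1;
  std::printf("n-split, rank 1 offset: %zu\n", 0 * src_n + n_rank * dst_n);        // 3
  // split along k (the loader's TP_1D_COLUMN call: k_rank = rank, n_rank = 0)
  const size_t dst_k = src_k / world_size, k_rank = 1;
  std::printf("k-split, rank 1 offset: %zu\n", k_rank * dst_k * src_n + 0 * dst_n);  // 12
  return 0;
}
```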
@@ -659,10 +678,20 @@
       model_buffer tmp_buf;
       model_file& file = file_loaders.at(shard.file_idx)->file;
       file.seek(shard.file_off, SEEK_SET);
-      tmp_buf.resize(lt.size * lt.world_size);
-      file.read_raw(tmp_buf.addr, lt.size * lt.world_size);
-      // only copy part of weight form the tmp_buf of origin file
-      memcpy(lt.data, tmp_buf.addr + lt.rank * lt.size, lt.size);
+      size_t num_rows = lt.ne.size() == 1 ? 1 : lt.ne.at(1);
+      if (lt.type == NE_TYPE_JBLAS) {
+        tmp_buf.resize(shard.size);
+        file.read_raw(tmp_buf.addr, shard.size);
+        void* dst_data = (void*)lt.data;
+        void* src_data = (void*)(tmp_buf.addr);
+        jblas_split_weight(&src_data, &dst_data, lt.world_size * num_rows, lt.ne.at(0), num_rows, lt.ne.at(0),
+                           lt.rank, 0);
+      } else {
+        // only copy part of the weight from the tmp_buf of the original file
+        tmp_buf.resize(lt.size * lt.world_size);
+        file.read_raw(tmp_buf.addr, lt.size * lt.world_size);
+        memcpy(lt.data, tmp_buf.addr + lt.rank * lt.size, lt.size);
+      }
     } else if (lt.split_type == TP_1D_COLUMN) {
       if (lt.size == 0) {
         return;
@@ -671,18 +700,36 @@
       model_buffer tmp_buf;
       model_file& file = file_loaders.at(shard.file_idx)->file;
       file.seek(shard.file_off, SEEK_SET);
-      tmp_buf.resize(lt.size * lt.world_size);
-      file.read_raw(tmp_buf.addr, lt.size * lt.world_size);
-      size_t offset = 0;
       size_t num_rows = lt.ne.size() == 1 ? 1 : lt.ne.at(1);
-      // different data type may have differnet per_row_size
-      size_t per_row_size = lt.size / num_rows;
-      for (size_t i = 0; i < num_rows; ++i) {
-        memcpy(lt.data + offset, tmp_buf.addr + lt.rank * per_row_size + i * lt.world_size * per_row_size,
-               per_row_size);
-        offset += per_row_size;
+      if (lt.type == NE_TYPE_JBLAS) {
+        tmp_buf.resize(shard.size);
+        file.read_raw(tmp_buf.addr, shard.size);
+        void* dst_data = (void*)lt.data;
+        void* src_data = (void*)(tmp_buf.addr);
+        jblas_split_weight(&src_data, &dst_data, num_rows, lt.world_size * lt.ne.at(0), num_rows, lt.ne.at(0), 0,
+                           lt.rank);
+      } else {
+        tmp_buf.resize(lt.size * lt.world_size);
+        file.read_raw(tmp_buf.addr, lt.size * lt.world_size);
+        size_t offset = 0;
+        // different data types may have different per_row_size
+        size_t per_row_size = lt.size / num_rows;
+        for (size_t i = 0; i < num_rows; ++i) {
+          memcpy(lt.data + offset, tmp_buf.addr + lt.rank * per_row_size + i * lt.world_size * per_row_size,
+                 per_row_size);
+          offset += per_row_size;
+        }
+        MODEL_ASSERT(offset == lt.size);
       }
-      MODEL_ASSERT(offset == lt.size);
+    } else if (lt.split_type == TP_1D_ONLY_MASTER) {
+      // only the master node loads the tensor; other nodes set it to zero
+      model_file& file = file_loaders.at(lt.shards.at(0).file_idx)->file;
+      file.seek(lt.shards.at(0).file_off, SEEK_SET);
+      if (lt.rank == 0) {
+        file.read_raw(lt.data, lt.size);
+      } else {
+        memset(lt.data, 0, lt.size);
+      }
     }
 #endif
     if (0) {
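The `TP_1D_ONLY_MASTER` branch explains the earlier GPT-J NOTICE: `.mlp.fc_out.bias` keeps its full shape on every rank, but only rank 0 reads real data while the others zero-fill, so the bias enters the reduced FFN result exactly once instead of `world_size` times. A minimal sketch of that invariant (toy numbers, not this repo's API):

```cpp
#include <cstdio>

// Sketch of why the loader zero-fills the bias on non-master ranks: if every
// rank contributed the full bias to a sum across ranks, it would be counted
// world_size times; zeroing all but rank 0 keeps exactly one copy in play.
int main() {
  const int world_size = 4;
  const float partial[4] = {0.5f, 1.0f, 1.5f, 2.0f};  // per-rank partial sums
  const float bias = 0.25f;
  float wrong = 0, right = 0;
  for (int r = 0; r < world_size; ++r) {
    wrong += partial[r] + bias;                    // bias counted 4 times
    right += partial[r] + (r == 0 ? bias : 0.0f);  // TP_1D_ONLY_MASTER load rule
  }
  std::printf("wrong=%g right=%g expected=%g\n", wrong, right, 5.25f);
  return 0;
}
```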
