Skip to content

Commit

Permalink
feat(llama.cpp): add flash_attention and no_kv_offloading (#2310)
Browse files Browse the repository at this point in the history
feat(llama.cpp): add flash_attn and no_kv_offload

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
  • Loading branch information
mudler authored May 13, 2024
1 parent 7123d07 commit e49ea01
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 0 deletions.
3 changes: 3 additions & 0 deletions backend/backend.proto
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,9 @@ message ModelOptions {
float YarnBetaSlow = 47;

string Type = 49;

bool FlashAttention = 56;
bool NoKVOffload = 57;
}

message Result {
Expand Down
3 changes: 3 additions & 0 deletions backend/cpp/llama/grpc-server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2254,6 +2254,9 @@ static void params_parse(const backend::ModelOptions* request,
}
params.use_mlock = request->mlock();
params.use_mmap = request->mmap();
params.flash_attn = request->flashattention();
params.no_kv_offload = request->nokvoffload();

params.embedding = request->embeddings();

if (request->ropescaling() == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_NONE; }
Expand Down
2 changes: 2 additions & 0 deletions core/backend/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ func gRPCModelOpts(c config.BackendConfig) *pb.ModelOptions {
MaxModelLen: int32(c.MaxModelLen),
TensorParallelSize: int32(c.TensorParallelSize),
MMProj: c.MMProj,
FlashAttention: c.FlashAttention,
NoKVOffload: c.NoKVOffloading,
YarnExtFactor: c.YarnExtFactor,
YarnAttnFactor: c.YarnAttnFactor,
YarnBetaFast: c.YarnBetaFast,
Expand Down
3 changes: 3 additions & 0 deletions core/config/backend_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,9 @@ type LLMConfig struct {
TensorParallelSize int `yaml:"tensor_parallel_size"` // vLLM
MMProj string `yaml:"mmproj"`

FlashAttention bool `yaml:"flash_attention"`
NoKVOffloading bool `yaml:"no_kv_offloading"`

RopeScaling string `yaml:"rope_scaling"`
ModelType string `yaml:"type"`

Expand Down

0 comments on commit e49ea01

Please sign in to comment.