From 14df08c6e171d846d868981e6933b16de22fe9b9 Mon Sep 17 00:00:00 2001
From: shraddhazpy
Date: Fri, 31 Jan 2025 09:28:46 +0530
Subject: [PATCH] feat: tokenization

Signed-off-by: shraddhazpy
---
 backend/cpp/llama/grpc-server.cpp       | 12 ++++++++++++
 core/backend/tokenize.go                | 11 +++++------
 core/http/endpoints/localai/tokenize.go |  5 ++---
 3 files changed, 19 insertions(+), 9 deletions(-)

diff --git a/backend/cpp/llama/grpc-server.cpp b/backend/cpp/llama/grpc-server.cpp
index 1e9a35517ff0..4daf84c6963a 100644
--- a/backend/cpp/llama/grpc-server.cpp
+++ b/backend/cpp/llama/grpc-server.cpp
@@ -2542,6 +2542,18 @@ class BackendServiceImpl final : public backend::Backend::Service {
         return grpc::Status::OK;
     }
 
+    grpc::Status TokenizeString(ServerContext* context, const backend::PredictOptions* request, backend::TokenizationResponse* response){
+        json data = parse_options(false, request, llama);
+
+        std::vector<llama_token> tokens = llama.tokenize(data["prompt"], false);
+
+        for (int i = 0; i < tokens.size(); i++) {
+            response->add_tokens(tokens[i]);
+        }
+
+        return grpc::Status::OK;
+    }
+
     grpc::Status GetMetrics(ServerContext* context, const backend::MetricsRequest* request, backend::MetricsResponse* response) {
         llama_client_slot* active_slot = llama.get_active_slot();
 
diff --git a/core/backend/tokenize.go b/core/backend/tokenize.go
index 2f813e18736b..1783083ba1a6 100644
--- a/core/backend/tokenize.go
+++ b/core/backend/tokenize.go
@@ -16,12 +16,7 @@ func ModelTokenize(s string, loader *model.ModelLoader, backendConfig config.Bac
 
 	opts := ModelOptions(backendConfig, appConfig, model.WithModel(modelFile))
 
-	if backendConfig.Backend == "" {
-		inferenceModel, err = loader.Load(opts...)
-	} else {
-		opts = append(opts, model.WithBackendString(backendConfig.Backend))
-		inferenceModel, err = loader.Load(opts...)
-	}
+	inferenceModel, err = loader.Load(opts...)
 	if err != nil {
 		return schema.TokenizeResponse{}, err
 	}
@@ -35,6 +30,10 @@ func ModelTokenize(s string, loader *model.ModelLoader, backendConfig config.Bac
 		return schema.TokenizeResponse{}, err
 	}
 
+	if resp.Tokens == nil {
+		resp.Tokens = make([]int32, 0)
+	}
+
 	return schema.TokenizeResponse{
 		Tokens: resp.Tokens,
 	}, nil
diff --git a/core/http/endpoints/localai/tokenize.go b/core/http/endpoints/localai/tokenize.go
index da110bf864e0..faa8a0a4a1b3 100644
--- a/core/http/endpoints/localai/tokenize.go
+++ b/core/http/endpoints/localai/tokenize.go
@@ -12,6 +12,7 @@ import (
 
 // TokenizeEndpoint exposes a REST API to tokenize the content
 // @Summary Tokenize the input.
+// @Param request body schema.TokenizeRequest true "Request"
 // @Success 200 {object} schema.TokenizeResponse "Response"
 // @Router /v1/tokenize [post]
 func TokenizeEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
@@ -51,8 +52,6 @@ func TokenizeEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, app
 			return err
 		}
 
-		c.JSON(tokenResponse)
-		return nil
-
+		return c.JSON(tokenResponse)
 	}
 }
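
For reference, below is a minimal client sketch (not part of the patch) showing how the endpoint wired up above could be exercised end to end. It assumes a LocalAI instance listening on localhost:8080 with a loaded model named "my-model", both hypothetical; the JSON field names ("model", "content", "tokens") are my reading of schema.TokenizeRequest and schema.TokenizeResponse and should be checked against the schema package.

// tokenize_client.go: hedged example client for POST /v1/tokenize.
// Host, port, and model name are assumptions, not taken from the patch.
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

func main() {
	// Build the request body; field names assumed from schema.TokenizeRequest.
	payload, err := json.Marshal(map[string]string{
		"model":   "my-model",       // hypothetical model name
		"content": "Hello tokenize", // text to tokenize
	})
	if err != nil {
		panic(err)
	}

	resp, err := http.Post("http://localhost:8080/v1/tokenize", "application/json", bytes.NewReader(payload))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	// Decode the response; shape assumed from schema.TokenizeResponse.
	var out struct {
		Tokens []int32 `json:"tokens"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		panic(err)
	}

	// Thanks to the nil-guard added in core/backend/tokenize.go, an empty
	// result decodes as an empty slice rather than JSON null.
	fmt.Println(out.Tokens)
}

Two details of the patch matter for such a client: the nil-guard in core/backend/tokenize.go means a prompt that yields no tokens serializes as "tokens": [] rather than "tokens": null, and the switch to "return c.JSON(tokenResponse)" in the endpoint surfaces JSON serialization errors to the caller instead of silently discarding them.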