diff --git a/backend/python/vllm/backend.py b/backend/python/vllm/backend.py index 2d8b55db6190..2cf15c1ca83e 100644 --- a/backend/python/vllm/backend.py +++ b/backend/python/vllm/backend.py @@ -135,6 +135,26 @@ async def Predict(self, request, context): res = await gen.__anext__() return res + def Embedding(self, request, context): + """ + A gRPC method that calculates embeddings for a given sentence. + + Args: + request: An EmbeddingRequest object that contains the request parameters. + context: A grpc.ServicerContext object that provides information about the RPC. + + Returns: + An EmbeddingResult object that contains the calculated embeddings. + """ + print("Calculated embeddings for: " + request.Embeddings, file=sys.stderr) + outputs = self.model.encode(request.Embeddings) + # Check if we have one result at least + if len(outputs) == 0: + context.set_code(grpc.StatusCode.INVALID_ARGUMENT) + context.set_details("No embeddings were calculated.") + return backend_pb2.EmbeddingResult() + return backend_pb2.EmbeddingResult(embeddings=outputs[0].outputs.embedding) + async def PredictStream(self, request, context): """ Generates text based on the given prompt and sampling parameters, and streams the results. diff --git a/backend/python/vllm/test.py b/backend/python/vllm/test.py index 83fb26518e76..9f325b103346 100644 --- a/backend/python/vllm/test.py +++ b/backend/python/vllm/test.py @@ -72,5 +72,28 @@ def test_text(self): except Exception as err: print(err) self.fail("text service failed") + finally: + self.tearDown() + + def test_embedding(self): + """ + This method tests if the embeddings are generated successfully + """ + try: + self.setUp() + with grpc.insecure_channel("localhost:50051") as channel: + stub = backend_pb2_grpc.BackendStub(channel) + response = stub.LoadModel(backend_pb2.ModelOptions(Model="intfloat/e5-mistral-7b-instruct")) + self.assertTrue(response.success) + embedding_request = backend_pb2.PredictOptions(Embeddings="This is a test sentence.") + embedding_response = stub.Embedding(embedding_request) + self.assertIsNotNone(embedding_response.embeddings) + # assert that is a list of floats + self.assertIsInstance(embedding_response.embeddings, list) + # assert that the list is not empty + self.assertTrue(len(embedding_response.embeddings) > 0) + except Exception as err: + print(err) + self.fail("Embedding service failed") finally: self.tearDown() \ No newline at end of file