diff --git a/test/pytest/test_example_gpt_fast.py b/test/pytest/test_example_gpt_fast.py
index 954d76478e..0888619a17 100644
--- a/test/pytest/test_example_gpt_fast.py
+++ b/test/pytest/test_example_gpt_fast.py
@@ -49,17 +49,17 @@ MAR_PARAMS = (
     {
         "nproc": 1,
-        "stream": "false",
+        "stream": "true",
         "compile": "false",
     },
     {
         "nproc": 4,
-        "stream": "false",
+        "stream": "true",
         "compile": "false",
     },
     {
         "nproc": 4,
-        "stream": f"false\n    speculate_k: 8\n    draft_checkpoint_path: '{(LLAMA_MODEL_PATH.parents[1] / 'Llama-2-7b-chat-hf' / 'model_int8.pth').as_posix()}'",
+        "stream": f"true\n    speculate_k: 8\n    draft_checkpoint_path: '{(LLAMA_MODEL_PATH.parents[1] / 'Llama-2-7b-chat-hf' / 'model_int8.pth').as_posix()}'",
         "compile": "true",
     },
 )
 
@@ -72,7 +72,8 @@
 ]
 
 EXPECTED_RESULTS = [
-    ", Paris, is a city of romance, fashion, and art. The city is home to the Eiffel Tower, the Louvre, and the Arc de Triomphe. Paris is also known for its cafes, restaurants",
+    # ", Paris, is a city of romance, fashion, and art. The city is home to the Eiffel Tower, the Louvre, and the Arc de Triomphe. Paris is also known for its cafes, restaurants",
+    " is Paris.\nThe capital of Germany is Berlin.\nThe capital of Italy is Rome.\nThe capital of Spain is Madrid.\nThe capital of the United Kingdom is London.\nThe capital of the European Union is Brussels.\n",
 ]
 
 
@@ -223,19 +224,21 @@ def test_gpt_fast_mar(model_name_and_stdout):
     response = requests.post(
         url=f"http://localhost:8080/predictions/{model_name}",
-        data=json.dumps(PROMPTS[0]),
+        data=json.dumps(
+            PROMPTS[0],
+        ),
+        stream=True,
     )
 
     assert response.status_code == 200
 
-    # Streaming currently does not work with tp
-    # assert response.headers["Transfer-Encoding"] == "chunked"
+    assert response.headers["Transfer-Encoding"] == "chunked"
 
-    # prediction = ""
-    # for chunk in response.iter_content(chunk_size=None):
-    #     if chunk:
-    #         prediction += chunk.decode("utf-8")
+    prediction = []
+    for chunk in response.iter_content(chunk_size=None):
+        if chunk:
+            prediction.append(chunk.decode("utf-8"))
 
-    # assert prediction == EXPECTED_RESULTS[0]
+    assert len(prediction) > 1
 
-    assert response.text == EXPECTED_RESULTS[0]
+    assert "".join(prediction) == EXPECTED_RESULTS[0]
 
diff --git a/test/pytest/test_send_intermediate_prediction_response.py b/test/pytest/test_send_intermediate_prediction_response.py
new file mode 100644
index 0000000000..ada8b8824f
--- /dev/null
+++ b/test/pytest/test_send_intermediate_prediction_response.py
@@ -0,0 +1,120 @@
+import json
+import shutil
+from pathlib import Path
+from unittest.mock import patch
+
+import pytest
+import requests
+import test_utils
+from model_archiver import ModelArchiverConfig
+
+CURR_FILE_PATH = Path(__file__).parent
+REPO_ROOT_DIR = CURR_FILE_PATH.parent.parent
+
+HANDLER_PY = """
+from ts.handler_utils.utils import send_intermediate_predict_response
+
+def handle(data, context):
+    if isinstance(data, list):
+        for i in range(3):
+            send_intermediate_predict_response(["hello"], context.request_ids, "Intermediate Prediction success", 200, context)
+        return ["hello world "]
+
+"""
+
+
+@pytest.fixture(scope="module")
+def model_name():
+    yield "tp_model"
+
+
+@pytest.fixture(scope="module")
+def work_dir(tmp_path_factory, model_name):
+    return Path(tmp_path_factory.mktemp(model_name))
+
+
+@pytest.fixture(scope="module", name="mar_file_path")
+def create_mar_file(work_dir, model_archiver, model_name):
+    mar_file_path = work_dir.joinpath(model_name + ".mar")
+
+    handler_py_file = work_dir / "handler.py"
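+    # Write the inline handler defined above into the temp work dir; it sends
+    # three intermediate chunks before returning the final response.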
+    handler_py_file.write_text(HANDLER_PY)
+
+    config = ModelArchiverConfig(
+        model_name=model_name,
+        version="1.0",
+        serialized_file=None,
+        model_file=None,
+        handler=handler_py_file.as_posix(),
+        extra_files=None,
+        export_path=work_dir,
+        requirements_file=None,
+        runtime="python",
+        force=False,
+        archive_format="default",
+        config_file=None,
+    )
+
+    with patch("archiver.ArgParser.export_model_args_parser", return_value=config):
+        model_archiver.generate_model_archive()
+
+    assert mar_file_path.exists()
+
+    yield mar_file_path.as_posix()
+
+    # Clean up files
+    mar_file_path.unlink(missing_ok=True)
+
+
+@pytest.fixture(scope="module", name="model_name")
+def register_model(mar_file_path, model_store, torchserve):
+    """
+    Register the model with torchserve
+    """
+    shutil.copy(mar_file_path, model_store)
+
+    file_name = Path(mar_file_path).name
+
+    model_name = Path(file_name).stem
+
+    params = (
+        ("model_name", model_name),
+        ("url", file_name),
+        ("initial_workers", "1"),
+        ("synchronous", "true"),
+        ("batch_size", "1"),
+    )
+
+    test_utils.reg_resp = test_utils.register_model_with_params(params)
+
+    yield model_name
+
+    test_utils.unregister_model(model_name)
+
+
+@pytest.mark.parametrize(("stream", "expected_chunks"), ((True, 4), (False, 1)))
+def test_echo_stream_inference(model_name, stream, expected_chunks):
+    """
+    End-to-end streaming test against a running torchserve instance
+    """
+
+    response = requests.post(
+        url=f"http://localhost:8080/predictions/{model_name}",
+        data=json.dumps(42),
+        stream=stream,
+    )
+
+    assert response.status_code == 200
+
+    assert response.headers["Transfer-Encoding"] == "chunked"
+
+    prediction = []
+    for chunk in response.iter_content(chunk_size=None):
+        if chunk:
+            prediction.append(chunk.decode("utf-8"))
+
+    assert len(prediction) == expected_chunks
+
+    assert "".join(prediction) == "hellohellohellohello world "
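
For reference, the streaming contract exercised above can also be driven
outside pytest with plain requests. A minimal client sketch, assuming a local
torchserve instance on the default inference port with the tp_model archive
from the fixtures registered:

    import requests

    # stream=True keeps the connection open so each intermediate response the
    # handler emits via send_intermediate_predict_response arrives as its own
    # chunk of the chunked transfer encoding.
    response = requests.post(
        "http://localhost:8080/predictions/tp_model",
        data="42",
        stream=True,
    )
    response.raise_for_status()
    assert response.headers["Transfer-Encoding"] == "chunked"

    for chunk in response.iter_content(chunk_size=None):
        if chunk:
            print(chunk.decode("utf-8"))  # "hello" three times, then "hello world "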