Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor sanity checks to use pytest #2221

Merged
merged 9 commits into from
Mar 11, 2024
27 changes: 27 additions & 0 deletions test/pytest/sanity/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import json
import sys
from pathlib import Path

import pytest

REPO_ROOT = Path(__file__).parents[3]


MAR_CONFIG = REPO_ROOT.joinpath("ts_scripts", "mar_config.json")


@pytest.fixture(name="gen_models", scope="module")
def load_gen_models() -> dict:
    """Read ts_scripts/mar_config.json and index its entries by model_name."""
    raw_entries = json.loads(MAR_CONFIG.read_text())
    return {entry["model_name"]: entry for entry in raw_entries}


@pytest.fixture(scope="module")
def ts_scripts_path():
    """Temporarily put the repo root on sys.path so `ts_scripts` is importable.

    Yields nothing; teardown removes exactly the entry that was added.
    """
    repo_root = REPO_ROOT.as_posix()
    sys.path.append(repo_root)

    yield

    # Remove the specific entry instead of sys.path.pop(): other fixtures or
    # tests may have appended to sys.path in the meantime, in which case
    # popping the tail would silently remove the wrong entry.
    try:
        sys.path.remove(repo_root)
    except ValueError:
        # Already removed elsewhere — cleanup is best-effort.
        pass
54 changes: 54 additions & 0 deletions test/pytest/sanity/test_config_snapshot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import json
from pathlib import Path

import pytest
import test_utils

REPO_ROOT = Path(__file__).parents[3]
SANITY_MODELS_CONFIG = REPO_ROOT.joinpath("ts_scripts", "configs", "sanity_models.json")


def load_resnet18() -> dict:
    """Return the resnet-18 entry from the sanity models config file."""
    with open(SANITY_MODELS_CONFIG) as config_file:
        entries = json.load(config_file)
    return [entry for entry in entries if entry["name"] == "resnet-18"][0]


@pytest.fixture(name="resnet18")
def generate_resnet18(model_store, gen_models, ts_scripts_path):
    """Generate the resnet-18 mar archive in the model store and yield its config."""
    from ts_scripts.marsgen import generate_model

    resnet18_config = load_resnet18()
    generate_model(gen_models[resnet18_config["name"]], model_store)

    yield resnet18_config


@pytest.fixture(scope="module")
def torchserve_with_snapshot(model_store):
    """Run TorchServe with config snapshotting enabled for the module's tests."""
    test_utils.torchserve_cleanup()

    test_utils.start_torchserve(
        model_store=model_store,
        no_config_snapshots=False,
        gen_mar=False,
    )

    yield

    test_utils.torchserve_cleanup()


def test_config_snapshotting(
    resnet18, model_store, torchserve_with_snapshot, ts_scripts_path
):
    """Register resnet-18, restart TorchServe, then verify the snapshot restores it."""
    from ts_scripts.sanity_utils import run_rest_test

    # Register and exercise the model; leave it registered so the
    # config snapshot records it.
    run_rest_test(resnet18, unregister_model=False)

    # Restart the server — with snapshotting on, the model should come
    # back without an explicit register call.
    test_utils.stop_torchserve()
    test_utils.start_torchserve(
        model_store=model_store,
        no_config_snapshots=False,
        gen_mar=False,
    )

    # Exercise the model again without re-registering it.
    run_rest_test(resnet18, register_model=False)
55 changes: 55 additions & 0 deletions test/pytest/sanity/test_model_registering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import json
from pathlib import Path

import pytest

REPO_ROOT = Path(__file__).parents[3]
SANITY_MODELS_CONFIG = REPO_ROOT.joinpath("ts_scripts", "configs", "sanity_models.json")


@pytest.fixture(scope="module")
def grpc_client_stubs(ts_scripts_path):
    """Generate gRPC client stub modules for the tests; remove them on teardown."""
    from ts_scripts.shell_utils import rm_file
    from ts_scripts.tsutils import generate_grpc_client_stubs

    generate_grpc_client_stubs()

    yield

    stub_glob = REPO_ROOT.joinpath("ts_scripts", "*_pb2*.py").as_posix()
    rm_file(stub_glob, True)


def load_models() -> list:
    """Read the sanity models config and return its list of model entries.

    Note: the previous `-> dict` annotation was wrong — the JSON document is a
    list of model dicts (it is filtered as a sequence and used as pytest
    `params` by callers).
    """
    with open(SANITY_MODELS_CONFIG) as f:
        models = json.load(f)
    return models


@pytest.fixture(name="model", params=load_models(), scope="module")
def models_to_validate(request, model_store, gen_models, ts_scripts_path):
    """Parametrized fixture yielding each sanity model, building its mar file when needed."""
    candidate = request.param

    # Only models listed in mar_config.json need a freshly generated archive.
    if candidate["name"] in gen_models:
        from ts_scripts.marsgen import generate_model

        generate_model(gen_models[candidate["name"]], model_store)

    yield candidate


def test_models_with_grpc(model, torchserve, ts_scripts_path, grpc_client_stubs):
    """Exercise the model over the gRPC API via the shared sanity helper."""
    from ts_scripts.sanity_utils import run_grpc_test as grpc_sanity_check

    grpc_sanity_check(model)


def test_models_with_rest(model, torchserve, ts_scripts_path):
    """Exercise the model over the REST API via the shared sanity helper."""
    from ts_scripts.sanity_utils import run_rest_test as rest_sanity_check

    rest_sanity_check(model)


def test_gpu_setup(ts_scripts_path):
    """Delegate GPU environment validation to the shared sanity helper."""
    # Alias the import so it does not shadow this test function's own name.
    from ts_scripts.sanity_utils import test_gpu_setup as check_gpu_setup

    check_gpu_setup()
125 changes: 59 additions & 66 deletions ts_scripts/marsgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,64 @@ def gen_mar(model_store=None):
print(f"## Symlink {src}, {dst} successfully.")


def generate_model(model, model_store_dir):
    """Build a single .mar archive for `model` inside `model_store_dir`.

    `model` is one entry of mar_config.json. The serialized weights come from
    one of three places:
      * "serialized_file_remote"  — downloaded from download.pytorch.org,
        unless "gen_scripted_file_path" names a script that produces them;
      * "serialized_file_local"   — used as-is from a local path;
      * neither                   — archived without a serialized file.

    A remotely sourced weights file is deleted again after archiving.
    Generated archive names are recorded in the module-level `mar_set`;
    archiver failures are logged and swallowed (best-effort, matching the
    batch behavior of generate_mars).
    """
    serialized_file_path = None
    if model.get("serialized_file_remote"):
        if model.get("gen_scripted_file_path"):
            # NOTE(review): return code is not checked — a failing generation
            # script only surfaces later when archiving cannot find the file.
            subprocess.run(["python", model["gen_scripted_file_path"]])
        else:
            serialized_model_file_url = (
                f"https://download.pytorch.org/models/{model['serialized_file_remote']}"
            )
            urllib.request.urlretrieve(
                serialized_model_file_url,
                f'{model_store_dir}/{model["serialized_file_remote"]}',
            )
        serialized_file_path = os.path.join(
            model_store_dir, model["serialized_file_remote"]
        )
    elif model.get("serialized_file_local"):
        serialized_file_path = model["serialized_file_local"]

    cmd = model_archiver_command_builder(
        model["model_name"],
        model["version"],
        model.get("model_file"),
        serialized_file_path,
        model.get("handler"),
        model.get("extra_files"),
        model.get("runtime"),
        model.get("archive_format", "zip-store"),
        model.get("requirements_file"),
        model.get("export_path", model_store_dir),
    )
    print(f"## In directory: {os.getcwd()} | Executing command: {cmd}\n")
    try:
        subprocess.check_call(cmd, shell=True)
        marfile = "{}.mar".format(model["model_name"])
        print("## {} is generated.\n".format(marfile))
        mar_set.add(marfile)
    except subprocess.CalledProcessError as exc:
        print("## {} creation failed !, error: {}\n".format(model["model_name"], exc))

    # Clean up a downloaded weights file. Also guard against
    # serialized_file_path being None (os.path.exists(None) raises TypeError).
    if (
        model.get("serialized_file_remote")
        and serialized_file_path
        and os.path.exists(serialized_file_path)
    ):
        os.remove(serialized_file_path)


def generate_mars(mar_config=MAR_CONFIG_FILE_PATH, model_store_dir=MODEL_STORE_DIR):
"""
By default generate_mars reads ts_scripts/mar_config.json and outputs mar files in dir model_store_gen
Expand All @@ -67,72 +125,7 @@ def generate_mars(mar_config=MAR_CONFIG_FILE_PATH, model_store_dir=MODEL_STORE_D
models = json.loads(f.read())

for model in models:
serialized_file_path = None
if model.get("serialized_file_remote") and model["serialized_file_remote"]:
if (
model.get("gen_scripted_file_path")
and model["gen_scripted_file_path"]
):
subprocess.run(["python", model["gen_scripted_file_path"]])
else:
serialized_model_file_url = (
"https://download.pytorch.org/models/{}".format(
model["serialized_file_remote"]
)
)
urllib.request.urlretrieve(
serialized_model_file_url,
f'{model_store_dir}/{model["serialized_file_remote"]}',
)
serialized_file_path = os.path.join(
model_store_dir, model["serialized_file_remote"]
)
elif model.get("serialized_file_local") and model["serialized_file_local"]:
serialized_file_path = model["serialized_file_local"]

handler = model.get("handler", None)

extra_files = model.get("extra_files", None)

runtime = model.get("runtime", None)

archive_format = model.get("archive_format", "zip-store")

requirements_file = model.get("requirements_file", None)

export_path = model.get("export_path", model_store_dir)

cmd = model_archiver_command_builder(
model["model_name"],
model["version"],
model.get("model_file", None),
serialized_file_path,
handler,
extra_files,
runtime,
archive_format,
requirements_file,
export_path,
)
print(f"## In directory: {os.getcwd()} | Executing command: {cmd}\n")
try:
subprocess.check_call(cmd, shell=True)
marfile = "{}.mar".format(model["model_name"])
print("## {} is generated.\n".format(marfile))
mar_set.add(marfile)
except subprocess.CalledProcessError as exc:
print(
"## {} creation failed !, error: {}\n".format(
model["model_name"], exc
)
)

if (
model.get("serialized_file_remote")
and model["serialized_file_remote"]
and os.path.exists(serialized_file_path)
):
os.remove(serialized_file_path)
generate_model(model, model_store_dir)
os.chdir(cwd)


Expand Down
58 changes: 12 additions & 46 deletions ts_scripts/sanity_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from ts_scripts import marsgen as mg
from ts_scripts import tsutils as ts
from ts_scripts import utils
from ts_scripts.tsutils import generate_grpc_client_stubs

REPO_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")
sys.path.append(REPO_ROOT)
Expand Down Expand Up @@ -163,51 +162,18 @@ def run_rest_test(model, register_model=True, unregister_model=True):


def test_sanity():
generate_grpc_client_stubs()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I expect the sanity test will fail without the gRPC client stubs.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That call is now a fixture.

generate_grpc_client_stubs()


print("## Started sanity tests")

models_to_validate = load_model_to_validate()

test_gpu_setup()

ts_log_file = os.path.join("logs", "ts_console.log")

os.makedirs("model_store", exist_ok=True)
os.makedirs("logs", exist_ok=True)

mg.mar_set = set(os.listdir("model_store"))
started = ts.start_torchserve(log_file=ts_log_file, gen_mar=False)
if not started:
sys.exit(1)

resnet18_model = models_to_validate["resnet-18"]

models_to_validate = {
k: v for k, v in models_to_validate.items() if k != "resnet-18"
}

for _, model in models_to_validate.items():
run_grpc_test(model)
run_rest_test(model)

run_rest_test(resnet18_model, unregister_model=False)

stopped = ts.stop_torchserve()
if not stopped:
sys.exit(1)

# Restarting torchserve
# This should restart with the generated snapshot and resnet-18 model should be automatically registered
started = ts.start_torchserve(log_file=ts_log_file, gen_mar=False)
if not started:
sys.exit(1)

run_rest_test(resnet18_model, register_model=False)

stopped = ts.stop_torchserve()
if not stopped:
sys.exit(1)
# Execute python tests
print("## Started TorchServe sanity pytests")
test_dir = os.path.join("test", "pytest", "sanity")
coverage_dir = os.path.join("ts")
report_output_dir = os.path.join(test_dir, "coverage.xml")

ts_test_cmd = f"python -m pytest --cov-report xml:{report_output_dir} --cov={coverage_dir} {test_dir}"
print(f"## In directory: {os.getcwd()} | Executing command: {ts_test_cmd}")
ts_test_error_code = os.system(ts_test_cmd)

if ts_test_error_code != 0:
sys.exit("## TorchServe sanity test failed !")


def test_workflow_sanity():
Expand Down
Loading