
Commit

Merge 24651ea into e205e6b
mreso authored Jun 2, 2023
2 parents e205e6b + 24651ea commit 0d8281e
Showing 4 changed files with 110 additions and 40 deletions.
26 changes: 21 additions & 5 deletions test/pytest/test_utils.py
@@ -5,9 +5,10 @@
import subprocess
import sys
import tempfile
import time
import threading
from os import path
from pathlib import Path
from subprocess import PIPE, STDOUT, Popen

import requests

@@ -21,6 +22,16 @@
CODEBUILD_WD = path.abspath(path.join(__file__, "../../.."))


class PrintPipeTillTheEnd(threading.Thread):
def __init__(self, pipe):
super().__init__()
self.pipe = pipe

def run(self):
for line in self.pipe.stdout:
print(line.decode("utf-8").strip())


def start_torchserve(
model_store=None, snapshot_file=None, no_config_snapshots=False, gen_mar=True
):
@@ -36,13 +47,18 @@ def start_torchserve(
if no_config_snapshots:
cmd.extend(["--no-config-snapshots"])
print(cmd)
subprocess.run(cmd)
time.sleep(10)

p = Popen(cmd, stdin=PIPE, stdout=PIPE, stderr=STDOUT)
for line in p.stdout:
print(line.decode("utf8").strip())
if "Model server started" in str(line).strip():
break
print_thread = PrintPipeTillTheEnd(p)
print_thread.start()


def stop_torchserve():
subprocess.run(["torchserve", "--stop"])
time.sleep(10)
subprocess.run(["torchserve", "--stop", "--foreground"])


def delete_all_snapshots():
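Note on the change above: the old test helper started TorchServe with subprocess.run and slept a fixed 10 seconds; the new one launches the server with Popen, blocks only until the "Model server started" line appears, and hands the remaining output to the PrintPipeTillTheEnd thread so the pipe keeps draining. A minimal, self-contained sketch of that launch-until-ready pattern (the function name and marker below are illustrative, not code from this commit):

import threading
from subprocess import PIPE, STDOUT, Popen


def launch_until_ready(cmd, ready_marker="Model server started"):
    # Merge stderr into stdout so a single pipe carries all server output.
    proc = Popen(cmd, stdin=PIPE, stdout=PIPE, stderr=STDOUT)

    # Block only until the readiness marker shows up, instead of sleeping a fixed time.
    for raw in proc.stdout:
        line = raw.decode("utf-8").strip()
        print(line)
        if ready_marker in line:
            break

    # Keep draining the rest of the output in the background so the server
    # never stalls on a full pipe buffer.
    def drain():
        for raw in proc.stdout:
            print(raw.decode("utf-8").strip())

    threading.Thread(target=drain, daemon=True).start()
    return proc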
2 changes: 1 addition & 1 deletion ts/arg_parser.py
@@ -61,7 +61,7 @@ def ts_parser():
parser.add_argument(
"--foreground",
help="Run the model server in foreground. If this option is disabled, the model server"
" will run in the background.",
" will run in the background. In combination with --stop, the program will wait for the model server to terminate.",
action="store_true",
)
parser.add_argument(
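The new sentence documents the behavior added in ts/model_server.py below: when --stop is combined with --foreground, the CLI waits for the running server to actually terminate instead of returning immediately. The reworked test teardown above relies on exactly that, for example:

import subprocess

# Blocks until the TorchServe process has exited (or the 60-second wait
# in model_server.py gives up), rather than returning right after the stop signal.
subprocess.run(["torchserve", "--stop", "--foreground"])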
8 changes: 7 additions & 1 deletion ts/model_server.py
@@ -47,7 +47,13 @@ def start():
try:
parent = psutil.Process(pid)
parent.terminate()
print("TorchServe has stopped.")
if args.foreground:
try:
parent.wait(timeout=60)
except psutil.TimeoutExpired:
print("Stopping TorchServe took too long.")
else:
print("TorchServe has stopped.")
except (OSError, psutil.Error):
print("TorchServe already stopped.")
os.remove(pid_file)
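The block above is the core of the change: after sending the terminate signal, the --foreground path waits up to 60 seconds for the server process to exit. A standalone sketch of that psutil pattern, with the PID lookup left to the caller (not code from this commit):

import psutil


def stop_and_wait(pid, timeout=60):
    """Terminate a process and block until it exits or the timeout passes."""
    try:
        proc = psutil.Process(pid)
        proc.terminate()  # SIGTERM on POSIX, TerminateProcess on Windows
        try:
            proc.wait(timeout=timeout)  # raises psutil.TimeoutExpired on timeout
            print("Process has stopped.")
        except psutil.TimeoutExpired:
            print("Stopping the process took too long.")
    except (OSError, psutil.Error):
        print("Process already stopped.")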
114 changes: 81 additions & 33 deletions ts_scripts/tsutils.py
@@ -1,66 +1,105 @@
import os
import platform
import sys
import time
import threading
from pathlib import Path
from subprocess import PIPE, STDOUT, Popen

import requests

from ts_scripts import marsgen as mg

torchserve_command = {
"Windows": "torchserve.exe",
"Darwin": "torchserve",
"Linux": "torchserve"
"Linux": "torchserve",
}

torch_model_archiver_command = {
"Windows": "torch-model-archiver.exe",
"Darwin": "torch-model-archiver",
"Linux": "torch-model-archiver"
}
"Windows": "torch-model-archiver.exe",
"Darwin": "torch-model-archiver",
"Linux": "torch-model-archiver",
}

torch_workflow_archiver_command = {
"Windows": "torch-workflow-archiver.exe",
"Darwin": "torch-workflow-archiver",
"Linux": "torch-workflow-archiver"
}
"Windows": "torch-workflow-archiver.exe",
"Darwin": "torch-workflow-archiver",
"Linux": "torch-workflow-archiver",
}


class LogPipeTillTheEnd(threading.Thread):
def __init__(self, pipe, log_file):
super().__init__()
self.pipe = pipe
self.log_file = log_file

def run(self):
with open(self.log_file, "a") as f:
for line in self.pipe.stdout:
f.write(line.decode("utf-8"))


def start_torchserve(
ncs=False, model_store="model_store", workflow_store="",
models="", config_file="", log_file="", wait_for=10, gen_mar=True):
ncs=False,
model_store="model_store",
workflow_store="",
models="",
config_file="",
log_file="",
gen_mar=True,
):
if gen_mar:
mg.gen_mar(model_store)
print("## Starting TorchServe")
cmd = f"{torchserve_command[platform.system()]} --start --model-store={model_store}"
cmd = [f"{torchserve_command[platform.system()]}"]
cmd.append("--start")
cmd.append(f"--model-store={model_store}")
if models:
cmd += f" --models={models}"
cmd.append(f"--models={models}")
if workflow_store:
cmd += f" --workflow-store={workflow_store}"
cmd.append(f"--workflow-store={workflow_store}")
if ncs:
cmd += " --ncs"
cmd.append("--ncs")
if config_file:
cmd += f" --ts-config={config_file}"
cmd.append(f"--ts-config={config_file}")
if log_file:
print(f"## Console logs redirected to file: {log_file}")
cmd += f" >> {log_file}"
print(f"## In directory: {os.getcwd()} | Executing command: {cmd}")
status = os.system(cmd)
print(f"## In directory: {os.getcwd()} | Executing command: {' '.join(cmd)}")
p = Popen(cmd, stdin=PIPE, stdout=PIPE, stderr=STDOUT)
if log_file:
Path(log_file).parent.absolute().mkdir(parents=True, exist_ok=True)
with open(log_file, "a") as f:
for line in p.stdout:
f.write(line.decode("utf-8"))
if "Model server started" in str(line).strip():
break
t = LogPipeTillTheEnd(p, log_file)
t.start()
else:
for line in p.stdout:
if "Model server started" in str(line).strip():
break

status = p.poll()
if status == 0:
print("## Successfully started TorchServe")
time.sleep(wait_for)
return True
else:
print("## TorchServe failed to start !")
return False


def stop_torchserve(wait_for=10):
def stop_torchserve():
print("## Stopping TorchServe")
cmd = f"{torchserve_command[platform.system()]} --stop"
cmd = [f"{torchserve_command[platform.system()]}"]
cmd.append("--stop")
print(f"## In directory: {os.getcwd()} | Executing command: {cmd}")
status = os.system(cmd)
p = Popen(cmd, stdin=PIPE, stdout=PIPE, stderr=STDOUT)

status = p.wait()
if status == 0:
print("## Successfully stopped TorchServe")
time.sleep(wait_for)
return True
else:
print("## TorchServe failed to stop !")
@@ -86,7 +125,9 @@ def register_model(model_name, protocol="http", host="localhost", port="8081"):
return response


def run_inference(model_name, file_name, protocol="http", host="localhost", port="8080", timeout=120):
def run_inference(
model_name, file_name, protocol="http", host="localhost", port="8080", timeout=120
):
print(f"## Running inference on {model_name} model")
url = f"{protocol}://{host}:{port}/predictions/{model_name}"
files = {"data": (file_name, open(file_name, "rb"))}
@@ -103,9 +144,11 @@ def unregister_model(model_name, protocol="http", host="localhost", port="8081")

def generate_grpc_client_stubs():
print("## Started generating gRPC client stubs")
cmd = "python -m grpc_tools.protoc --proto_path=frontend/server/src/main/resources/proto/ --python_out=ts_scripts " \
"--grpc_python_out=ts_scripts frontend/server/src/main/resources/proto/inference.proto " \
"frontend/server/src/main/resources/proto/management.proto"
cmd = (
"python -m grpc_tools.protoc --proto_path=frontend/server/src/main/resources/proto/ --python_out=ts_scripts "
"--grpc_python_out=ts_scripts frontend/server/src/main/resources/proto/inference.proto "
"frontend/server/src/main/resources/proto/management.proto"
)
status = os.system(cmd)
if status != 0:
print("Could not generate gRPC client stubs")
@@ -115,9 +158,7 @@ def generate_grpc_client_stubs():
def register_workflow(workflow_name, protocol="http", host="localhost", port="8081"):
print(f"## Registering {workflow_name} workflow")
model_zoo_url = "https://torchserve.s3.amazonaws.com"
params = (
("url", f"{model_zoo_url}/war_files/{workflow_name}.war"),
)
params = (("url", f"{model_zoo_url}/war_files/{workflow_name}.war"),)
url = f"{protocol}://{host}:{port}/workflows"
response = requests.post(url, params=params, verify=False)
return response
@@ -130,7 +171,14 @@ def unregister_workflow(workflow_name, protocol="http", host="localhost", port="
return response


def workflow_prediction(workflow_name, file_name, protocol="http", host="localhost", port="8080", timeout=120):
def workflow_prediction(
workflow_name,
file_name,
protocol="http",
host="localhost",
port="8080",
timeout=120,
):
print(f"## Running inference on {workflow_name} workflow")
url = f"{protocol}://{host}:{port}/wfpredict/{workflow_name}"
files = {"data": (file_name, open(file_name, "rb"))}
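With the reworked helpers, a typical end-to-end script built on ts_scripts/tsutils.py starts the server, exercises a model, and shuts down roughly as sketched below. The model name and input file are placeholders, not taken from this commit, and gen_mar=False skips regenerating model archives:

from ts_scripts import tsutils

# Start TorchServe; the helper now blocks until "Model server started" is logged.
if not tsutils.start_torchserve(
    model_store="model_store", log_file="logs/ts_console.log", gen_mar=False
):
    raise RuntimeError("TorchServe failed to start")

# Placeholder model from the public model zoo and a placeholder input file.
print(tsutils.register_model("mnist").status_code)
print(tsutils.run_inference("mnist", "examples/image_classifier/mnist/test_data/0.png").status_code)

tsutils.unregister_model("mnist")

# The helper now waits for the --stop command to finish instead of sleeping.
tsutils.stop_torchserve()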
