Adding mps support to base handler and regression test #3048

Merged: 20 commits, Apr 9, 2024

Changes from 14 commits
@@ -5,9 +5,11 @@
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.util.SelfSignedCertificate;
import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.lang.reflect.Field;
import java.lang.reflect.Type;
import java.net.InetAddress;
@@ -835,6 +837,28 @@ private static int getAvailableGpu() {
            for (String id : ids) {
                gpuIds.add(Integer.parseInt(id));
            }
        } else if (System.getProperty("os.name").startsWith("Mac")) {
            Process process = Runtime.getRuntime().exec("system_profiler SPDisplaysDataType");
            int ret = process.waitFor();
            if (ret != 0) {
                return 0;
            }

            BufferedReader reader =
                    new BufferedReader(new InputStreamReader(process.getInputStream()));
            String line;
            while ((line = reader.readLine()) != null) {
                if (line.contains("Chipset Model:") && !line.contains("Apple M1")) {
                    return 0;
                }
                if (line.contains("Total Number of Cores:")) {
                    String[] parts = line.split(":");
                    if (parts.length >= 2) {
                        return (Integer.parseInt(parts[1].trim()));
                    }
                }
            }
            throw new AssertionError("Unexpected response.");
        } else {
            Process process =
                    Runtime.getRuntime().exec("nvidia-smi --query-gpu=index --format=csv");
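For intuition, here is a minimal Python sketch of the same detection logic (illustrative only, not part of the PR; the helper name `available_mps_cores` and the use of `subprocess` are assumptions):

```python
# Minimal sketch, assuming `system_profiler SPDisplaysDataType` emits lines
# such as "Chipset Model: Apple M1" and "Total Number of Cores: 8".
import subprocess


def available_mps_cores() -> int:  # hypothetical helper, mirrors the Java branch
    result = subprocess.run(
        ["system_profiler", "SPDisplaysDataType"],
        capture_output=True,
        text=True,
    )
    if result.returncode != 0:
        return 0
    for line in result.stdout.splitlines():
        # Any non-Apple-M1 chipset means no MPS "GPU" is reported.
        if "Chipset Model:" in line and "Apple M1" not in line:
            return 0
        # The GPU core count is reported as the number of available "GPUs".
        if "Total Number of Cores:" in line:
            return int(line.split(":", 1)[1].strip())
    raise AssertionError("Unexpected system_profiler output")
```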
@@ -105,4 +105,18 @@ public void testNoWorkflowState() throws ReflectiveOperationException, IOException
                workingDir + "/frontend/archive/src/test/resources/models",
                configManager.getWorkflowStore());
    }

    @Test
    public void testNumGpuM1() throws ReflectiveOperationException, IOException {
        System.setProperty("tsConfigFile", "src/test/resources/config_test_env.properties");
        ConfigManager.Arguments args = new ConfigManager.Arguments();
        args.setModels(new String[] {"noop_v0.1"});
        args.setSnapshotDisabled(true);
        ConfigManager.init(args);
        ConfigManager configManager = ConfigManager.getInstance();
        String arch = System.getProperty("os.arch");
        if (arch.equals("aarch64")) {
            Assert.assertTrue(configManager.getNumberOfGpu() > 0);
        }
    }
}
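On Apple silicon, `os.arch` reports `aarch64` and `getAvailableGpu()` takes the new `system_profiler` branch, so the assertion checks that at least one GPU core is detected; on other architectures the test is effectively a no-op.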
181 changes: 181 additions & 0 deletions test/pytest/test_device_config.py
@@ -0,0 +1,181 @@
import shutil
from pathlib import Path
from unittest.mock import patch
import tempfile

import pytest
import test_utils
import requests
import os
import platform
from model_archiver import ModelArchiverConfig

CURR_FILE_PATH = Path(__file__).parent
REPO_ROOT_DIR = CURR_FILE_PATH.parent.parent
ROOT_DIR = os.path.join(tempfile.gettempdir(), "workspace")
REPO_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../")
data_file_zero = os.path.join(REPO_ROOT, "test/pytest/test_data/0.png")
config_file = os.path.join(REPO_ROOT, "test/resources/config_token.properties")
mnist_scriptes_py = os.path.join(REPO_ROOT, "examples/image_classifier/mnist/mnist.py")

HANDLER_PY = """
from ts.torch_handler.base_handler import BaseHandler

class deviceHandler(BaseHandler):

def initialize(self, context):
super().initialize(context)
assert self.get_device().type == "mps"
"""

MODEL_CONFIG_YAML = """
#frontend settings
# TorchServe frontend parameters
minWorkers: 1
batchSize: 4
maxWorkers: 4
"""

MODEL_CONFIG_YAML_GPU = """
#frontend settings
# TorchServe frontend parameters
minWorkers: 1
batchSize: 4
maxWorkers: 4
deviceType: "gpu"
"""

MODEL_CONFIG_YAML_CPU = """
#frontend settings
# TorchServe frontend parameters
minWorkers: 1
batchSize: 4
maxWorkers: 4
deviceType: "cpu"
"""


@pytest.fixture(scope="module")
def model_name():
yield "mnist"

@pytest.fixture(scope="module")
def work_dir(tmp_path_factory, model_name):
return Path(tmp_path_factory.mktemp(model_name))

@pytest.fixture(scope="module")
def model_config_name(request):
def get_config(param):
if param == "cpu":
return MODEL_CONFIG_YAML_CPU
elif param == "gpu":
return MODEL_CONFIG_YAML_GPU
else:
return MODEL_CONFIG_YAML

return get_config(request.param)

@pytest.fixture(scope="module", name="mar_file_path")
def create_mar_file(work_dir, model_archiver, model_name, model_config_name):


mar_file_path = work_dir.joinpath(model_name + ".mar")

model_config_yaml_file = work_dir / "model_config.yaml"
model_config_yaml_file.write_text(model_config_name)

model_py_file = work_dir / "model.py"

model_py_file.write_text(mnist_scriptes_py)

handler_py_file = work_dir / "handler.py"
handler_py_file.write_text(HANDLER_PY)

config = ModelArchiverConfig(
model_name=model_name,
version="1.0",
serialized_file=None,
model_file=mnist_scriptes_py, #model_py_file.as_posix(),
handler=handler_py_file.as_posix(),
extra_files=None,
export_path=work_dir,
requirements_file=None,
runtime="python",
force=False,
archive_format="default",
config_file=model_config_yaml_file.as_posix(),
)

with patch("archiver.ArgParser.export_model_args_parser", return_value=config):
model_archiver.generate_model_archive()

assert mar_file_path.exists()

yield mar_file_path.as_posix()

# Clean up files

mar_file_path.unlink(missing_ok=True)

# Clean up files

@pytest.fixture(scope="module", name="model_name")
def register_model(mar_file_path, model_store, torchserve):
"""
Register the model in torchserve
"""
shutil.copy(mar_file_path, model_store)

file_name = Path(mar_file_path).name

model_name = Path(file_name).stem

params = (
("model_name", model_name),
("url", file_name),
("initial_workers", "1"),
("synchronous", "true"),
("batch_size", "1"),
)

test_utils.reg_resp = test_utils.register_model_with_params(params)

yield model_name

test_utils.unregister_model(model_name)


@pytest.mark.skipif(platform.machine() != "arm64", reason="Test only runs on Apple silicon (arm64) Macs")
@pytest.mark.parametrize("model_config_name", ["gpu"], indirect=True)
def test_m1_device(model_name, model_config_name):
    response = requests.get(f"http://localhost:8081/models/{model_name}")

    print("-----TEST-----")
    print(response.content)
    assert response.status_code == 200, "Describe API should return 200 when deviceType is gpu"


@pytest.mark.skipif(platform.machine() != "arm64", reason="Test only runs on Apple silicon (arm64) Macs")
@pytest.mark.parametrize("model_config_name", ["cpu"], indirect=True)
def test_m1_device_cpu(model_name, model_config_name):
    response = requests.get(f"http://localhost:8081/models/{model_name}")

    print("-----TEST-----")
    print(response.content)
    assert response.status_code == 404, "Model should fail to register with deviceType cpu (handler asserts an mps device)"


@pytest.mark.skipif(platform.machine() != "arm64", reason="Test only runs on Apple silicon (arm64) Macs")
@pytest.mark.parametrize("model_config_name", ["default"], indirect=True)
def test_m1_device_default(model_name, model_config_name):
    response = requests.get(f"http://localhost:8081/models/{model_name}")

    print("-----TEST-----")
    print(response.content)
    assert response.status_code == 200, "Describe API should return 200 with the default config (mps selected)"
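All three tests are gated on `platform.machine() == "arm64"`, so on an Apple-silicon Mac they can be run with, e.g., `python -m pytest test/pytest/test_device_config.py`, assuming the usual TorchServe pytest setup that provides the `model_archiver`, `model_store`, and `torchserve` fixtures.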
1 change: 1 addition & 0 deletions test/resources/model-config.yaml
@@ -0,0 +1 @@
mps: "enable"
14 changes: 13 additions & 1 deletion ts/torch_handler/base_handler.py
@@ -113,7 +113,7 @@ class BaseHandler(abc.ABC):
    Base default handler to load torchscript or eager mode [state_dict] models
    Also, provides handle method per torch serve custom model specification
    """

    def __init__(self):
        self.model = None
        self.mapping = None
@@ -144,11 +144,15 @@ def initialize(self, context):
            self.model_yaml_config = context.model_yaml_config

        properties = context.system_properties

        if torch.cuda.is_available() and properties.get("gpu_id") is not None:
            self.map_location = "cuda"
            self.device = torch.device(
                self.map_location + ":" + str(properties.get("gpu_id"))
            )
        elif torch.backends.mps.is_available() and properties.get("gpu_id") is not None:
            self.map_location = "mps"
            self.device = torch.device("mps")
        elif XLA_AVAILABLE:
            self.device = xm.xla_device()
        else:
@@ -524,3 +528,11 @@ def describe_handle(self):
        # pylint: disable=unnecessary-pass
        pass
        # pylint: enable=unnecessary-pass

    def get_device(self):
        """Get device

        Returns:
            torch.device: the device the handler's model is loaded on
        """
        return self.device
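Taken together, the `initialize()` changes above yield the following device-selection precedence. This is a summary sketch, not code from the PR: `select_device` and the `xla_available` parameter stand in for the handler's inline logic and its `XLA_AVAILABLE` flag, and the final cpu fallback is assumed from the truncated else branch.

```python
import torch


def select_device(gpu_id, xla_available=False):
    """Sketch of the device precedence in BaseHandler.initialize()."""
    if torch.cuda.is_available() and gpu_id is not None:
        return torch.device(f"cuda:{gpu_id}")  # NVIDIA GPU assigned by the frontend
    if torch.backends.mps.is_available() and gpu_id is not None:
        return torch.device("mps")  # Apple-silicon GPU via Metal Performance Shaders
    if xla_available:
        import torch_xla.core.xla_model as xm  # requires torch_xla

        return xm.xla_device()
    return torch.device("cpu")  # assumed fallback from the truncated else branch
```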