From d609b0e09eeb5bbbe46bafddb7ca797bab9dee68 Mon Sep 17 00:00:00 2001 From: Xin Yang Date: Sat, 16 Nov 2024 16:16:11 -0800 Subject: [PATCH] [lora] Add adapter OOM unit tests --- .../src/test/resources/adaptecho/model.py | 16 ++--- .../java/ai/djl/serving/ModelServerTest.java | 62 +++++++++++++++++++ 2 files changed, 70 insertions(+), 8 deletions(-) diff --git a/engines/python/src/test/resources/adaptecho/model.py b/engines/python/src/test/resources/adaptecho/model.py index 84fe4fb83..3bf4c9760 100644 --- a/engines/python/src/test/resources/adaptecho/model.py +++ b/engines/python/src/test/resources/adaptecho/model.py @@ -23,8 +23,9 @@ def register_adapter(inputs: Input): global adapters name = inputs.get_property("name") if inputs.contains_key("error"): - return Output().error(f"error", - message=f"Failed to register adapter: {name}") + raise ValueError(f"Failed to register adapter: {name}") + if inputs.contains_key("oom"): + raise MemoryError adapters[name] = inputs return Output(message=f"Adapter {name} registered") @@ -33,11 +34,11 @@ def update_adapter(inputs: Input): global adapters name = inputs.get_property("name") if name not in adapters: - return Output().error(f"error", - message=f"Adapter {name} not registered.") + raise ValueError(f"Adapter {name} not registered.") if inputs.contains_key("error"): - return Output().error(f"error", - message=f"Failed to update adapter: {name}") + raise ValueError(f"Failed to update adapter: {name}") + if inputs.contains_key("oom"): + raise MemoryError adapters[name] = inputs return Output(message=f"Adapter {name} updated") @@ -46,8 +47,7 @@ def unregister_adapter(inputs: Input): global adapters name = inputs.get_property("name") if name not in adapters: - return Output().error(f"error", - message=f"Adapter {name} not registered.") + raise ValueError(f"Adapter {name} not registered.") del adapters[name] return Output(message=f"Adapter {name} unregistered") diff --git a/serving/src/test/java/ai/djl/serving/ModelServerTest.java b/serving/src/test/java/ai/djl/serving/ModelServerTest.java index 29ac2add2..df4575375 100644 --- a/serving/src/test/java/ai/djl/serving/ModelServerTest.java +++ b/serving/src/test/java/ai/djl/serving/ModelServerTest.java @@ -297,9 +297,11 @@ public void test() testRegisterAdapterConflict(); testRegisterAdapterModelNotFound(); testRegisterAdapterHandlerError(); + testRegisterAdapterOom(); testUpdateAdapterModelNotFound(); testUpdateAdapterNotFound(); testUpdateAdapterHandlerError(); + testUpdateAdapterOom(); testListAdapterModelNotFound(); testDescribeAdapterModelNotFound(); testDescribeAdapterNotFound(); @@ -1080,6 +1082,32 @@ private void testRegisterAdapterHandlerError() throws InterruptedException { assertFalse(resp.getAdapters().stream().anyMatch(a -> "adaptable2".equals(a.getName()))); } + private void testRegisterAdapterOom() throws InterruptedException { + logTestFunction(); + Channel channel = connect(Connector.ConnectorType.MANAGEMENT); + assertNotNull(channel); + + String modelName = "adaptecho"; + String adapterName = "adaptable2"; + String strModelPrefix = "/models/" + modelName; + String url = strModelPrefix + "/adapters?name=" + adapterName + "&src=src&oom=true"; + request(channel, HttpMethod.POST, url); + channel.closeFuture().sync(); + channel.close().sync(); + assertHttpCode(HttpResponseStatus.INSUFFICIENT_STORAGE.code()); + + // Assert adapter not added + channel = connect(Connector.ConnectorType.MANAGEMENT); + assertNotNull(channel); + + url = strModelPrefix + "/adapters"; + request(channel, HttpMethod.GET, url); + assertHttpOk(); + + ListAdaptersResponse resp = JsonUtils.GSON.fromJson(result, ListAdaptersResponse.class); + assertFalse(resp.getAdapters().stream().anyMatch(a -> "adaptable2".equals(a.getName()))); + } + private void testUpdateAdapter(Channel channel, boolean modelPrefix) throws InterruptedException { logTestFunction(); @@ -1166,6 +1194,40 @@ private void testUpdateAdapterHandlerError() throws InterruptedException { assertFalse(resp.isPin()); } + private void testUpdateAdapterOom() throws InterruptedException { + logTestFunction(); + Channel channel = connect(Connector.ConnectorType.MANAGEMENT); + assertNotNull(channel); + + String modelName = "adaptecho"; + String adapterName = "adaptable"; + String strModelPrefix = "/models/" + modelName; + String url = + strModelPrefix + + "/adapters/" + + adapterName + + "/update?src=src1&load=false&oom=true"; + request(channel, HttpMethod.POST, url); + channel.closeFuture().sync(); + channel.close().sync(); + assertHttpCode(HttpResponseStatus.INSUFFICIENT_STORAGE.code()); + + // Assert adapter not updated + channel = connect(Connector.ConnectorType.MANAGEMENT); + assertNotNull(channel); + + url = strModelPrefix + "/adapters/" + adapterName; + request(channel, HttpMethod.GET, url); + assertHttpOk(); + + DescribeAdapterResponse resp = + JsonUtils.GSON.fromJson(result, DescribeAdapterResponse.class); + assertEquals(resp.getName(), adapterName); + assertEquals(resp.getSrc(), "src"); + assertTrue(resp.isLoad()); + assertFalse(resp.isPin()); + } + private void testAdapterMissing() throws InterruptedException { logTestFunction(); Channel channel = connect(Connector.ConnectorType.INFERENCE);