Skip to content

Commit

Permalink
Merge pull request #1311 from pytorch/issue_1309
Browse files Browse the repository at this point in the history
support custom return code
  • Loading branch information
lxning authored Nov 10, 2021
2 parents 433584f + 0d29576 commit d258da8
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 0 deletions.
27 changes: 27 additions & 0 deletions docs/custom_service.md
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,33 @@ class ModelHandler(object):

### Advanced custom handlers

#### Returning custom error codes

To return a custom error code to the user from a custom handler with a `module`-level entry point:

```python
from ts.utils.util import PredictionException
def handle(data, context):
# Some unexpected error - returning error code 513
raise PredictionException("Some Prediction Error", 513)
```

To return a custom error code to the user from a custom handler with a `class`-level entry point:

```python
from ts.torch_handler.base_handler import BaseHandler
from ts.utils.util import PredictionException

class ModelHandler(BaseHandler):
"""
A custom model handler implementation.
"""

def handle(self, data, context):
# Some unexpected error - returning error code 513
raise PredictionException("Some Prediction Error", 513)
```

#### Writing a custom handler from scratch for Prediction and Explanations Request

*You should generally derive from BaseHandler and ONLY override methods whose behavior needs to change!* As you can see in the examples, most of the time you only need to override `preprocess` or `postprocess`
Expand Down
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -1072,6 +1072,51 @@ public void testPredictionMemoryError() throws InterruptedException {
@Test(
alwaysRun = true,
dependsOnMethods = {"testPredictionMemoryError"})
/**
 * Verifies that a custom handler can propagate a custom HTTP status code.
 *
 * <p>Registers the {@code pred-custom-return-code} model, sends a prediction request that the
 * handler rejects with a custom code, asserts the frontend surfaces status 599 to the client,
 * then unregisters the model. Depends on {@code testPredictionMemoryError} running first (see
 * the {@code @Test} annotation above).
 */
public void testPredictionCustomErrorCode() throws InterruptedException {
// Load the model
Channel channel = TestUtils.connect(ConnectorType.MANAGEMENT_CONNECTOR, configManager);
Assert.assertNotNull(channel);
TestUtils.setResult(null);
TestUtils.setLatch(new CountDownLatch(1));

TestUtils.registerModel(
channel, "pred-custom-return-code.mar", "pred-custom-return-code", true, false);
TestUtils.getLatch().await();
// Registration itself must succeed with 200 OK before the prediction is attempted.
Assert.assertEquals(TestUtils.getHttpStatus(), HttpResponseStatus.OK);
channel.close().sync();

// Test for prediction
channel = TestUtils.connect(ConnectorType.INFERENCE_CONNECTOR, configManager);
Assert.assertNotNull(channel);
TestUtils.setResult(null);
TestUtils.setLatch(new CountDownLatch(1));
DefaultFullHttpRequest req =
new DefaultFullHttpRequest(
HttpVersion.HTTP_1_1,
HttpMethod.POST,
"/predictions/pred-custom-return-code");
// Payload the custom handler is expected to reject with its custom error code.
req.content().writeCharSequence("data=invalid_output", CharsetUtil.UTF_8);

channel.writeAndFlush(req);
TestUtils.getLatch().await();

// 599 is the custom (non-standard) code raised by the handler via PredictionException;
// presumably set inside pred-custom-return-code.mar — confirm against the model archive.
Assert.assertEquals(TestUtils.getHttpStatus().code(), 599);
channel.close().sync();

// Unload the model
channel = TestUtils.connect(ConnectorType.MANAGEMENT_CONNECTOR, configManager);
TestUtils.setHttpStatus(null);
TestUtils.setLatch(new CountDownLatch(1));
Assert.assertNotNull(channel);

TestUtils.unregisterModel(channel, "pred-custom-return-code", null, false);
TestUtils.getLatch().await();
Assert.assertEquals(TestUtils.getHttpStatus(), HttpResponseStatus.OK);
}

@Test(
alwaysRun = true,
dependsOnMethods = {"testPredictionCustomErrorCode"})
public void testErrorBatch() throws InterruptedException {
Channel channel = TestUtils.connect(ConnectorType.MANAGEMENT_CONNECTOR, configManager);
Assert.assertNotNull(channel);
Expand Down
4 changes: 4 additions & 0 deletions ts/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from ts.context import Context, RequestProcessor
from ts.metrics.metrics_store import MetricsStore
from ts.protocol.otf_message_handler import create_predict_response
from ts.utils.util import PredictionException

PREDICTION_METRIC = 'PredictionTime'
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -99,6 +100,9 @@ def predict(self, batch):
# noinspection PyBroadException
try:
ret = self._entry_point(input_batch, self.context)
except PredictionException as e:
logger.error("Prediction error", exc_info=True)
return create_predict_response(None, req_id_map, e.message, e.error_code)
except MemoryError:
logger.error("System out of memory", exc_info=True)
return create_predict_response(None, req_id_map, "Out of resources", 507)
Expand Down
9 changes: 9 additions & 0 deletions ts/utils/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,12 @@ def map_class_to_label(probs, mapping=None, lbl_classes=None):
]

return results

class PredictionException(Exception):
    """Raised by a custom handler to return a custom error code to the client.

    ``error_code`` is reported alongside ``message`` when the exception is
    caught by the prediction path (see the ``except PredictionException``
    handler in ts/service.py, which passes both to the response).
    """

    def __init__(self, message, error_code=500):
        # Human-readable description of the prediction failure.
        self.message = message
        # Status code to surface to the caller; defaults to 500.
        self.error_code = error_code
        super().__init__(message)

    def __str__(self):
        # Bug fix: the original template had no "{...}" placeholders, so
        # .format() was a no-op and str() always returned the literal
        # "message : error_code". Use a real format string instead.
        return "{message} : {error_code}".format(
            message=self.message, error_code=self.error_code
        )

0 comments on commit d258da8

Please sign in to comment.