Skip to content

Commit

Permalink
Merge pull request #25 from philschmid/add-error-for-csv-without-header
Browse files Browse the repository at this point in the history
Add exception for CSV without headers.
  • Loading branch information
philschmid authored Aug 23, 2021
2 parents 2fb97f9 + 4a74cf1 commit 2743a73
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 16 deletions.
28 changes: 16 additions & 12 deletions src/sagemaker_huggingface_inference_toolkit/decoder_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import csv
import datetime
import json
from io import StringIO
import csv

import numpy as np
from sagemaker_inference.decoder import (
_npy_to_numpy,
_npz_to_sparse,
)
from sagemaker_inference.encoder import (
_array_to_npy,
)
from sagemaker_inference import (
content_types,
errors,
)
from sagemaker_inference import content_types, errors
from sagemaker_inference.decoder import _npy_to_numpy, _npz_to_sparse
from sagemaker_inference.encoder import _array_to_npy

from mms.service import PredictionException


def decode_json(content):
Expand All @@ -42,6 +37,13 @@ def decode_csv(string_like): # type: (str) -> np.array
(dict): dictonatry for input
"""
stream = StringIO(string_like)
# detects if the incoming csv has headers
if not any(header in string_like.splitlines()[0].lower() for header in ["question", "context", "inputs"]):
raise PredictionException(
f"You need to provide the correct CSV with Header columns to use it with the inference toolkit default handler.",
400,
)
# reads csv as io
request_list = list(csv.DictReader(stream))
if "inputs" in request_list[0].keys():
return {"inputs": [entry["inputs"] for entry in request_list]}
Expand Down Expand Up @@ -123,6 +125,8 @@ def decode(content, content_type=content_types.JSON):
return decoder(content)
except KeyError:
raise errors.UnsupportedFormatError(content_type)
except PredictionException as pred_err:
raise pred_err


def encode(content, content_type=content_types.JSON):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@
import importlib
import logging
import os
import time
import sys
import time
from abc import ABC

from sagemaker_inference import environment, utils, content_types
from sagemaker_inference import content_types, environment, utils
from transformers.pipelines import SUPPORTED_TASKS

from mms.service import PredictionException
from mms import metrics
from mms.service import PredictionException
from sagemaker_huggingface_inference_toolkit import decoder_encoder
from sagemaker_huggingface_inference_toolkit.transformers_utils import (
_is_gpu_available,
Expand Down
10 changes: 10 additions & 0 deletions tests/unit/test_decoder_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
# limitations under the License.
import json

import pytest

from mms.service import PredictionException
from sagemaker_huggingface_inference_toolkit import decoder_encoder


Expand Down Expand Up @@ -46,6 +49,13 @@ def test_decode_csv():
assert decoded_data == {"inputs": ["I love you", "I like you"]}


def test_decode_csv_without_header():
with pytest.raises(PredictionException):
decoder_encoder.decode_csv(
"where do i live?,My name is Philipp and I live in Nuremberg\r\nwhere is Berlin?,Berlin is the capital of Germany"
)


def test_encode_json():
encoded_data = decoder_encoder.encode_json(ENCODE_JSON_INPUT)
assert json.loads(encoded_data) == ENCODE_JSON_INPUT
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/test_handler_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,14 @@
import tempfile

import pytest
from sagemaker_inference import content_types
from transformers.testing_utils import require_torch, slow

from mms.context import Context, RequestProcessor
from mms.metrics.metrics_store import MetricsStore
from sagemaker_huggingface_inference_toolkit import handler_service
from sagemaker_huggingface_inference_toolkit.transformers_utils import _load_model_from_hub, get_pipeline
from sagemaker_inference import content_types


TASK = "text-classification"
MODEL = "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english"
Expand Down

0 comments on commit 2743a73

Please sign in to comment.