diff --git a/flash/core/serve/flash_components.py b/flash/core/serve/flash_components.py index c02b182039..c7370f93d8 100644 --- a/flash/core/serve/flash_components.py +++ b/flash/core/serve/flash_components.py @@ -2,6 +2,7 @@ from typing import Any, Callable, Mapping import torch +from torch import Tensor from flash.core.data.batch import _ServeInputProcessor from flash.core.data.data_module import DataModule @@ -41,6 +42,8 @@ def serialize(self, outputs) -> Any: # pragma: no cover result = self._output(output) if isinstance(result, Mapping): result = result[DataKeys.PREDS] + if isinstance(result, Tensor): + result = result.tolist() results.append(result) if len(results) == 1: return results[0] diff --git a/flash_examples/serve/image_classification/inference_server.py b/flash_examples/serve/image_classification/inference_server.py index 2f9e06e50a..4d57fa96ab 100644 --- a/flash_examples/serve/image_classification/inference_server.py +++ b/flash_examples/serve/image_classification/inference_server.py @@ -16,4 +16,4 @@ model = ImageClassifier.load_from_checkpoint( "https://flash-weights.s3.amazonaws.com/0.7.0/image_classification_model.pt" ) -model.serve() +model.serve(output="labels") diff --git a/requirements/serve.txt b/requirements/serve.txt index da47094178..4f39f4e22f 100644 --- a/requirements/serve.txt +++ b/requirements/serve.txt @@ -4,16 +4,10 @@ pyyaml cytoolz graphviz tqdm -# until 1.0 release fastapi docs recommend pinning to MINOR releases. -# https://fastapi.tiangolo.com/deployment/#fastapi-versions -fastapi>=0.65.2,<0.66.0 -# to have full feature control of fastapi, manually install optional -# dependencies rather than installing fastapi[all] -# https://fastapi.tiangolo.com/#optional-dependencies -pydantic>1.8.1,<2.0.0 +fastapi>=0.65.2 +pydantic>1.8.1 starlette==0.14.2 -uvicorn[standard]>=0.12.0,<0.14.0 +uvicorn[standard]>=0.12.0 aiofiles jinja2 -importlib-metadata>=0.12,<3;python_version<"3.8" torchvision