Skip to content

Commit

Permalink
feat: inference server option returns base64 image
Browse files Browse the repository at this point in the history
Passing `{"server": {"base64": true}}` option on inference api predict
request will ask api server to return base64 image in request response.
  • Loading branch information
alx authored and beniz committed Jan 18, 2024
1 parent 76ce6b4 commit 5bc8f44
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 1 deletion.
28 changes: 27 additions & 1 deletion server/joligen_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import shutil
from pathlib import Path
import time
import base64

import torch.multiprocessing as mp

Expand Down Expand Up @@ -331,7 +332,32 @@ async def predict(request: Request):
# run in synchronous mode
try:
ctx[opt.name].join()
return {"message": "ok", "name": opt.name, "status": "stopped"}
message = {"message": "ok", "name": opt.name, "status": "stopped"}

if (
"base64" in predict_body["server"]
and predict_body["server"]["base64"] == True
):

if hasattr(opt, "img_out") and opt.img_out is not None:

with open(opt.img_out, "rb") as f:
message["base64"] = [base64.b64encode(f.read())]

elif hasattr(opt, "dir_out") and opt.dir_out is not None:

message["base64"] = []
for sample_index in range(opt.nb_samples):
for output in ["cond", "generated", "orig", "y_t"]:
img_out = os.path.join(
opt.dir_out,
f"%s_%i_%s.png" % (opt.name, sample_index, output),
)
with open(img_out, "rb") as f:
message["base64"].append(base64.b64encode(f.read()))

return message

except Exception as e:
raise HTTPException(status_code=400, detail="{0}".format(e))

Expand Down
56 changes: 56 additions & 0 deletions tests/test_api_predict_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from fastapi.testclient import TestClient
from starlette.websockets import WebSocketDisconnect
from PIL import Image
import base64

sys.path.append(sys.path[0] + "/..")
from server.joligen_api import app
Expand Down Expand Up @@ -196,3 +197,58 @@ def test_predict_endpoint_sync_success(dataroot, api):
os.remove(img_out)

os.remove(img_resized)


def test_predict_endpoint_sync_base64(dataroot, api):

name = "joligen_utest_api_palette"
dir_model = "/".join(dataroot.split("/")[:-1])

if not os.path.exists(dir_model):
pytest.fail("Model does not exist")

model_in_file = os.path.abspath(os.path.join(dir_model, name, "latest_net_G_A.pth"))

if not os.path.exists(model_in_file):
pytest.fail(f"Model file does not exist: %s" % model_in_file)

img_in = os.path.join(dataroot, "trainA", "img", "00000.png")

if not os.path.exists(img_in):
pytest.fail(f"Image input file does not exist: %s" % img_in)

payload = {
"predict_options": {
"model_in_file": model_in_file,
"img_in": img_in,
"dir_out": dir_model,
},
"server": {"sync": True, "base64": True},
}

response = api.post("/predict", json=payload)
assert response.status_code == 200

json_response = response.json()
predict_name = json_response["name"]

assert "message" in json_response
assert "status" in json_response
assert "name" in json_response
assert json_response["message"] == "ok"
assert json_response["status"] == "stopped"
assert json_response["name"].startswith("predict_")
assert len(json_response["name"]) > 0

assert len(json_response["base64"]) == 4
for index, output in enumerate(["cond", "generated", "orig", "y_t"]):

img_out = os.path.join(dir_model, f"%s_0_%s.png" % (predict_name, output))
assert os.path.exists(img_out)

with open(img_out, "rb") as f:
base64_out = base64.b64encode(f.read()).decode("utf-8")
assert base64_out == json_response["base64"][index]

if os.path.exists(img_out):
os.remove(img_out)
56 changes: 56 additions & 0 deletions tests/test_api_predict_gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import shutil
from fastapi.testclient import TestClient
from starlette.websockets import WebSocketDisconnect
import base64

sys.path.append(sys.path[0] + "/..")
from server.joligen_api import app
Expand Down Expand Up @@ -181,3 +182,58 @@ def test_predict_endpoint_sync_success(dataroot, api):
assert os.path.exists(img_out)
if os.path.exists(img_out):
os.remove(img_out)


def test_predict_endpoint_sync_base64(dataroot, api):

name = "joligen_utest_api_cut"
dir_model = "/".join(dataroot.split("/")[:-1])

if not os.path.exists(dir_model):
pytest.fail("Model does not exist")

model_in_file = os.path.abspath(os.path.join(dir_model, name, "latest_net_G_A.pth"))

if not os.path.exists(model_in_file):
pytest.fail(f"Model file does not exist: %s" % model_in_file)

img_in = os.path.join(dataroot, "trainA", "img", "00000.png")

if not os.path.exists(img_in):
pytest.fail(f"Image input file does not exist: %s" % img_in)

img_out = os.path.abspath(os.path.join(dir_model, "out_success_sync.jpg"))

if os.path.exists(img_out):
os.remove(img_out)

payload = {
"predict_options": {
"model_in_file": model_in_file,
"img_in": img_in,
"img_out": img_out,
},
"server": {"sync": True, "base64": True},
}

response = api.post("/predict", json=payload)
assert response.status_code == 200

json_response = response.json()
assert "message" in json_response
assert "status" in json_response
assert "name" in json_response
assert json_response["message"] == "ok"
assert json_response["status"] == "stopped"
assert json_response["name"].startswith("predict_")
assert len(json_response["name"]) > 0

assert os.path.exists(img_out)

assert len(json_response["base64"]) == 1
with open(img_out, "rb") as f:
base64_out = base64.b64encode(f.read()).decode("utf-8")
assert base64_out == json_response["base64"][0]

if os.path.exists(img_out):
os.remove(img_out)

0 comments on commit 5bc8f44

Please sign in to comment.