Skip to content

Commit

Permalink
feat: api inference for gan and diffusion
Browse files Browse the repository at this point in the history
  • Loading branch information
alx authored and Bycob committed Jan 15, 2024
1 parent 18bbe2b commit 6fd43d8
Show file tree
Hide file tree
Showing 24 changed files with 896 additions and 19 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ checkpoints/
results/
build/
dist/
logs/
*.png
torch.egg-info/
*/**/__pycache__
Expand Down
2 changes: 1 addition & 1 deletion client.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
python client.py --method training_status --host jg_server_host --port jg_server_port
# Stop a training
python client.py --method training_status --host jg_server_host --port jg_server_port --name training_name
python client.py --method stop_training --host jg_server_host --port jg_server_port --name training_name
"""

import requests
Expand Down
124 changes: 123 additions & 1 deletion docs/source/server.rst
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,126 @@ Client: stopping a training job

.. code:: bash
python client.py --method training_status --host jg_server_host --port jg_server_port --name training_name
python client.py --method stop_training --host jg_server_host --port jg_server_port --name training_name
.. _client_stop:

**********************
Curl: inference on GAN
**********************

The following `curl` command creates the `$PATH_TO_FILES/output.jpg` file and also returns its content in `curl_response["base64"]` as a base64-encoded string.

The base64 string is then decoded and saved into a local `output.jpg` file.

.. code:: bash
PATH_TO_MODEL=/home/joligen/models/horse2zebra/
PATH_TO_FILES=/home/joligen/files/
JOLIGEN_SERVER=localhost:18100
BASE64_IMG=$(curl "http://$JOLIGEN_SERVER/predict" -X POST \
--data-raw \
'{
"predict_options": {
"model_in_file": "'"$PATH_TO_MODEL"'/latest_net_G_A.pth",
"img_in": "'"$PATH_TO_FILES"'/source.jpg",
"img_out": "'"$PATH_TO_FILES"'/output.jpg"
},
"server": {
"sync": true,
"base64": true
}
}' | jq -r '.base64[0]')
echo "$BASE64_IMG" | base64 -d > output.jpg
If `payload["server"]["base64"]` is not enabled, the file is still created on disk, but it is not returned inside `curl_response`.

If `payload["server"]["sync"]` is not enabled, the inference process runs asynchronously and `curl_response` only contains a status message stating that the process has started.

In `async` mode, process status can be followed using websocket:

.. code:: bash
JOLIGEN_SERVER=localhost:18100
PREDICT_NAME=$(curl "http://$JOLIGEN_SERVER/predict" -X POST \
--data-raw \
'{
"predict_options": {
"model_in_file": "'"$PATH_TO_MODEL"'/latest_net_G_A.pth",
"img_in": "'"$PATH_TO_FILES"'/source.jpg",
"img_out": "'"$PATH_TO_FILES"'/output.jpg"
}
}' | jq -r .name)
WEBSOCKET_URL="http://$JOLIGEN_SERVER/ws/predict/$PREDICT_NAME"
curl -N -i \
-H "Connection: Upgrade" \
-H "Upgrade: websocket" \
"$WEBSOCKET_URL" | jq .
Websocket messages are sent by the API server. The websocket connection is closed when the inference is finished or if an error is encountered.

****************************
Curl: inference on Diffusion
****************************

The following `curl` command creates the `$PATH_TO_FILES/output.jpg` file and also returns its content in `curl_response["base64"]` as a base64-encoded string.

The base64 string is then decoded and saved into a local `output.jpg` file.

.. code:: bash
PATH_TO_MODEL=/home/joligen/models/horse2zebra/
PATH_TO_FILES=/home/joligen/files/
JOLIGEN_SERVER=localhost:18100
BASE64_IMG=$(curl "http://$JOLIGEN_SERVER/predict" -X POST \
--data-raw \
'{
"predict_options": {
"model_in_file": "'"$PATH_TO_MODEL"'/latest_net_G_A.pth",
"img_in": "'"$PATH_TO_FILES"'/source.jpg",
"dir_out": "'"$PATH_TO_FILES"'"
},
"server": {
"sync": true,
"base64": true
}
}' | jq -r '.base64[0]')
echo "$BASE64_IMG" | base64 -d > output.jpg
If `payload["server"]["base64"]` is not enabled, the file is still created on disk, but it is not returned inside `curl_response`.

If `payload["server"]["sync"]` is not enabled, the inference process runs asynchronously and `curl_response` only contains a status message stating that the process has started.

In `async` mode, process status can be followed using websocket:

.. code:: bash
JOLIGEN_SERVER=localhost:18100
PREDICT_NAME=$(curl "http://$JOLIGEN_SERVER/predict" -X POST \
--data-raw \
'{
"predict_options": {
"model_in_file": "'"$PATH_TO_MODEL"'/latest_net_G_A.pth",
"img_in": "'"$PATH_TO_FILES"'/source.jpg",
"dir_out": "'"$PATH_TO_FILES"'"
}
}' | jq -r .name)
WEBSOCKET_URL="http://$JOLIGEN_SERVER/ws/predict/$PREDICT_NAME"
curl -N -i \
-H "Connection: Upgrade" \
-H "Upgrade: websocket" \
"$WEBSOCKET_URL" | jq .
Websocket messages are sent by the API server. The websocket connection is closed when the inference is finished or if an error is encountered.
3 changes: 2 additions & 1 deletion options/inference_diffusion_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ def initialize(self, parser):
"""Define the common options that are used in both training and test."""
parser = super().initialize(parser)

parser.add_argument("--name", help="inference process name", default="predict")

parser.add_argument(
"--model_in_file",
help="file path to generator model (.pth file)",
Expand All @@ -19,7 +21,6 @@ def initialize(self, parser):
help="The directory where to output result images",
required=True,
)
parser.add_argument("--name", help="generated img name", default="img")

parser.add_argument(
"--img_width",
Expand Down
2 changes: 2 additions & 0 deletions options/inference_gan_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ def initialize(self, parser):
"""Define the common options that are used in both training and test."""
parser = super().initialize(parser)

parser.add_argument("--name", help="inference process name", default="predict")

parser.add_argument(
"--model_in_file",
help="file path to generator model (.pth file)",
Expand Down
4 changes: 4 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
torch==2.1 -f https://download.pytorch.org/whl/cu117/torch_stable.html
torchvision==0.16
numpy==1.23.1
pluggy==1.3.0
Pillow
timm
visdom
Expand All @@ -23,3 +24,6 @@ piq
git+https://github.com/ChaoningZhang/MobileSAM.git
Ninja
lpips
pytest
pytest-asyncio
httpx
38 changes: 37 additions & 1 deletion scripts/gen_single_image.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import os
import sys
import logging

jg_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../")
sys.path.append(jg_dir)
Expand Down Expand Up @@ -48,14 +49,41 @@ def load_model(modelpath, model_in_file, cpu, gpuid):
return model, opt, device


def inference_logger(name):
    """Configure logging for one inference run and return its logger.

    Records are routed to ``<LOG_PATH>/<name>.log`` (opened in ``"w"`` mode,
    so an existing file is truncated) and to a ``StreamHandler``.
    ``LOG_PATH`` is read from the environment; when unset it falls back to
    the ``logs`` directory one level above this script's directory.

    Args:
        name: inference process name, used as the log file stem and as the
            last component of the logger name.

    Returns:
        The ``logging.Logger`` named ``"inference gen_single_image <name>"``.
    """
    process_name = "gen_single_image"
    log_dir = os.environ.get(
        "LOG_PATH", os.path.join(os.path.dirname(__file__), "../logs")
    )
    # exist_ok=True avoids the check-then-create race when several inference
    # processes start at the same time and all try to create the directory.
    os.makedirs(log_dir, exist_ok=True)

    # NOTE: logging.basicConfig() does nothing when the root logger already
    # has handlers, so only the first call in a process installs handlers.
    logging.basicConfig(
        level=logging.DEBUG,
        handlers=[
            logging.FileHandler(os.path.join(log_dir, f"{name}.log"), mode="w"),
            logging.StreamHandler(),
        ],
    )

    return logging.getLogger("inference %s %s" % (process_name, name))


def inference(args):
modelpath = args.model_in_file.replace(os.path.basename(args.model_in_file), "")

PROGRESS_NUM_STEPS = 6
logger = inference_logger(args.name)
logger.info(f"[1/%i] launch inference" % PROGRESS_NUM_STEPS)

modelpath = os.path.dirname(args.model_in_file)
print("modelpath=%s" % modelpath)

model, opt, device = load_model(
modelpath, os.path.basename(args.model_in_file), args.cpu, args.gpuid
)

logger.info(f"[2/%i] model loaded" % PROGRESS_NUM_STEPS)

# reading image
img_width = args.img_width if args.img_width is not None else opt.data_crop_size
img_height = args.img_height if args.img_height is not None else opt.data_crop_size
Expand All @@ -64,6 +92,8 @@ def inference(args):
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(img, (img_width, img_height), interpolation=cv2.INTER_CUBIC)

logger.info(f"[3/%i] image loaded" % PROGRESS_NUM_STEPS)

# preprocessing
tranlist = [
transforms.ToTensor(),
Expand All @@ -88,9 +118,13 @@ def inference(args):
else:
img_tensor = img_tensor.unsqueeze(0)

logger.info(f"[4/%i] preprocessing finished" % PROGRESS_NUM_STEPS)

# run through model
out_tensor = model(img_tensor)[0].detach()

logger.info(f"[5/%i] out tensor available" % PROGRESS_NUM_STEPS)

# post-processing
out_img = out_tensor.data.cpu().float().numpy()
print(out_img.shape)
Expand All @@ -102,6 +136,8 @@ def inference(args):
out_img = np.concatenate((original_img, out_img), axis=1)

cv2.imwrite(args.img_out, out_img)

logger.info(f"[6/%i] success - %s" % (PROGRESS_NUM_STEPS, args.img_out))
print("Successfully generated image ", args.img_out)


Expand Down
60 changes: 60 additions & 0 deletions scripts/gen_single_image_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import sys
import tempfile
import warnings
import logging

import cv2
import numpy as np
Expand Down Expand Up @@ -193,8 +194,13 @@ def generate(
alg_palette_ddim_num_steps,
alg_palette_ddim_eta,
model_prior_321_backwardcompatibility,
logger,
iteration,
nb_samples,
**unused_options,
):

PROGRESS_NUM_STEPS = 4
# seed
if seed >= 0:
torch.manual_seed(seed)
Expand Down Expand Up @@ -223,6 +229,12 @@ def generate(
if alg_diffusion_cond_image_creation is not None:
opt.alg_diffusion_cond_image_creation = alg_diffusion_cond_image_creation

if logger:
logger.info(
f"[it: %i/%i] - [1/%i] model loaded"
% (iteration, nb_samples, PROGRESS_NUM_STEPS)
)

conditioning = opt.alg_diffusion_cond_embed

for i, delta_values in enumerate(mask_delta):
Expand Down Expand Up @@ -381,6 +393,12 @@ def generate(
if ref is not None:
ref = cv2.resize(ref, (img_width, img_height))

if logger:
logger.info(
f"[it: %i/%i] - [2/%i] image loaded"
% (iteration, nb_samples, PROGRESS_NUM_STEPS)
)

# insert cond image into original image
generated_bbox = None
if cond_in:
Expand Down Expand Up @@ -620,6 +638,12 @@ def generate(
out_tensor
) # out_img = out_img.detach().data.cpu().float().numpy()[0]

if logger:
logger.info(
f"[it: %i/%i] - [3/%i] processing completed"
% (iteration, nb_samples, PROGRESS_NUM_STEPS)
)

""" post-processing
out_img = (np.transpose(out_img, (1, 2, 0)) + 1) / 2.0 * 255.0
Expand Down Expand Up @@ -675,10 +699,42 @@ def generate(

print("Successfully generated image ", name)

if logger:
logger.info(
f"[it: %i/%i] - [4/%i] image written"
% (iteration, nb_samples, PROGRESS_NUM_STEPS)
)

return out_img_real_size, model, opt


def inference_logger(name):
    """Configure logging for one diffusion inference run and return its logger.

    Records are routed to ``<LOG_PATH>/<name>.log`` (opened in ``"w"`` mode,
    so an existing file is truncated) and to a ``StreamHandler``.
    ``LOG_PATH`` is read from the environment; when unset it falls back to
    the ``logs`` directory one level above this script's directory.

    Args:
        name: inference process name, used as the log file stem and as the
            last component of the logger name.

    Returns:
        The ``logging.Logger`` named
        ``"inference gen_single_image_diffusion <name>"``.
    """
    process_name = "gen_single_image_diffusion"
    log_dir = os.environ.get(
        "LOG_PATH", os.path.join(os.path.dirname(__file__), "../logs")
    )
    # exist_ok=True avoids the check-then-create race when several inference
    # processes start at the same time and all try to create the directory.
    os.makedirs(log_dir, exist_ok=True)

    # NOTE: logging.basicConfig() does nothing when the root logger already
    # has handlers, so only the first call in a process installs handlers.
    logging.basicConfig(
        level=logging.DEBUG,
        handlers=[
            logging.FileHandler(os.path.join(log_dir, f"{name}.log"), mode="w"),
            logging.StreamHandler(),
        ],
    )

    return logging.getLogger("inference %s %s" % (process_name, name))


def inference(args):

PROGRESS_NUM_STEPS = 6
logger = inference_logger(args.name)

args.logger = logger

if len(args.mask_delta_ratio[0]) == 1 and args.mask_delta_ratio[0][0] == 0.0:
mask_delta = args.mask_delta
else:
Expand All @@ -693,11 +749,15 @@ def inference(args):
args.lopt = None

for i in tqdm(range(args.nb_samples)):
args.iteration = i + 1
logger.info(f"[it: %i/%i] launch inference" % (args.iteration, args.nb_samples))
args.name = real_name + "_" + str(i).zfill(len(str(args.nb_samples)))
frame, lmodel, lopt = generate(**vars(args))
args.lmodel = lmodel
args.lopt = lopt

logger.info(f"success - end of inference")


if __name__ == "__main__":
args = InferenceDiffusionOptions().parse()
Expand Down
Loading

0 comments on commit 6fd43d8

Please sign in to comment.