forked from microsoft/OmniParser
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathclient.py
118 lines (97 loc) · 3.78 KB
/
client.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
"""
This module provides a command-line interface to interact with the OmniParser Gradio server.
Usage:
python client.py "http://<server_ip>:7861" "path/to/image.jpg"
View results:
JSON: cat result_data_<timestamp>.json
Image:
macOS: open output_image_<timestamp>.png
Windows: start output_image_<timestamp>.png
Linux: xdg-open output_image_<timestamp>.png
Result data format:
{
"label_coordinates": {
"0": [x1, y1, width, height], // Normalized coordinates for each bounding box
"1": [x1, y1, width, height],
...
},
"parsed_content_list": [
"Text Box ID 0: [content]",
"Text Box ID 1: [content]",
...,
"Icon Box ID X: [description]",
...
]
}
Note: The parsed_content_list includes both text box contents and icon descriptions.
"""
import base64
import json
import mimetypes
import os
import shutil
from datetime import datetime
from io import BytesIO

import fire
from gradio_client import Client
from loguru import logger
from PIL import Image
def _build_image_input(image_path: str) -> dict:
    """Read the image at ``image_path`` and build the dictionary payload for the server.

    The file's bytes are base64-encoded and embedded in a ``data:`` URL. The MIME type is
    guessed from the file extension; when the extension is unknown or does not map to an
    image type, ``image/png`` is used (the value this client always sent before).
    """
    guessed_type, _ = mimetypes.guess_type(image_path)
    if guessed_type is None or not guessed_type.startswith("image/"):
        guessed_type = "image/png"

    with open(image_path, "rb") as image_file:
        encoded_image = base64.b64encode(image_file.read()).decode("utf-8")

    return {
        "path": None,
        "url": f"data:{guessed_type};base64,{encoded_image}",
        "size": None,
        "orig_name": image_path,
        "mime_type": guessed_type,
        "is_stream": False,
        "meta": {},
    }


def predict(server_url: str, image_path: str, box_threshold: float = 0.05, iou_threshold: float = 0.1):
    """
    Makes a prediction using the OmniParser Gradio client with the provided server URL and image.

    On success, two files are written using relative paths (so they land in the current
    working directory), both named with a second-resolution timestamp:
    ``result_data_<timestamp>.json`` holding the parsed result, and
    ``output_image_<timestamp>.png`` holding a copy of the image file returned by the server.

    Failures while connecting to the server or reading ``image_path`` propagate to the caller.
    Any exception raised by the prediction call or while handling its result is logged
    (message plus traceback) and not re-raised.

    Args:
        server_url (str): The URL of the OmniParser Gradio server.
        image_path (str): Path to the image file to be processed.
        box_threshold (float): Box threshold value (default: 0.05).
        iou_threshold (float): IOU threshold value (default: 0.1).
    """
    client = Client(server_url)

    # Generate a timestamp for unique file naming
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    # Load and encode the image in the format expected by the server
    image_input = _build_image_input(image_path)

    # Make the prediction
    try:
        result = client.predict(
            image_input,    # image input as dictionary
            box_threshold,  # box_threshold
            iou_threshold,  # iou_threshold
            api_name="/process",
        )

        # The endpoint's result is unpacked as an (output image, JSON string) pair
        output_image, result_json = result
        logger.info("Prediction completed successfully")

        # Parse the JSON string into a Python object
        result_data = json.loads(result_json)

        # Extract label_coordinates and parsed_content_list
        label_coordinates = result_data['label_coordinates']
        parsed_content_list = result_data['parsed_content_list']
        logger.info(f"{label_coordinates=}")
        logger.info(f"{parsed_content_list=}")

        # Save result data to JSON file (explicit encoding so the output does not
        # depend on the platform's default locale)
        result_data_path = f"result_data_{timestamp}.json"
        with open(result_data_path, "w", encoding="utf-8") as json_file:
            json.dump(result_data, json_file, indent=4)
        logger.info(f"Parsed content saved to: {result_data_path}")

        # Save the output image; it is only copied when the server result is a
        # string naming an existing local file
        output_image_path = f"output_image_{timestamp}.png"
        if isinstance(output_image, str) and os.path.exists(output_image):
            shutil.copy(output_image, output_image_path)
            logger.info(f"Output image saved to: {output_image_path}")
        else:
            logger.warning(f"Unexpected output_image format or file not found: {output_image}")
    except Exception as e:
        # Top-level CLI boundary: report the failure instead of propagating it
        logger.error(f"An error occurred: {str(e)}")
        logger.exception("Traceback:")
if __name__ == "__main__":
    # Run as a script: hand predict() to Fire, which is invoked from the command line
    # as: python client.py <server_url> <image_path>
    fire.Fire(predict)