-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_net.py
178 lines (146 loc) · 5.83 KB
/
test_net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
#!/usr/bin/env python
# --------------------------------------------------------------------------------
# MPViT: Multi-Path Vision Transformer for Dense Prediction
# Copyright (c) 2022 Electronics and Telecommunications Research Institute (ETRI).
# All Rights Reserved.
# Written by Youngwan Lee
# --------------------------------------------------------------------------------
"""
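Detection inference script: runs a trained detector over a directory of images (and rasterised PDFs) and writes COCO-format predictions.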
"""
import os
import itertools
import cv2
import torch
import numpy as np
import json
import io
from typing import Any, Dict, List, Set
from pdf2image import convert_from_path
from detectron2.data import build_detection_train_loader
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, launch
from detectron2.evaluation import COCOEvaluator
from detectron2.evaluation.coco_evaluation import instances_to_coco_json
from detectron2.solver.build import maybe_add_gradient_clipping
from ditod import add_vit_config
from ditod import DetrDatasetMapper
from detectron2.data.datasets import register_coco_instances
import logging
from detectron2.utils.logger import setup_logger
from detectron2.utils import comm
from detectron2.engine.defaults import create_ddp_model
import weakref
from detectron2.engine.train_loop import AMPTrainer, SimpleTrainer
from ditod import MyDetectionCheckpointer, ICDAREvaluator
from ditod import MyTrainer, DefaultPredictor, CustomPredictor
import pycocotools.mask as mask_util
from detectron2.structures import Boxes, BoxMode
def setup(args):
    """Build and freeze the detectron2 config for this run.

    Starts from detectron2's defaults, registers the extra ViT keys via
    ``add_vit_config``, layers the YAML file named by ``args.config_file`` and
    then the ``args.opts`` overrides on top, freezes the result and hands it to
    ``default_setup`` together with ``args``.

    Args:
        args: parsed command-line namespace (reads ``config_file`` and ``opts``).

    Returns:
        The frozen config node.
    """
    config = get_cfg()
    add_vit_config(config)

    # Later merges win: file values override defaults, CLI opts override file.
    config.merge_from_file(args.config_file)
    config.merge_from_list(args.opts)

    config.freeze()
    default_setup(config, args)
    return config
def custom_instance_to_coco(instances, img_id):
    """Convert one image's predicted ``Instances`` into COCO result records.

    Args:
        instances (Instances): predictions for a single image; must expose
            ``pred_boxes``, ``scores`` and ``pred_classes`` and may also carry
            ``pred_masks``.
        img_id: stored verbatim in every record's ``"image_id"`` field
            (``main`` passes the image file name, not an integer id).

    Returns:
        list[dict]: one record per instance with ``image_id``, ``category_id``,
        ``bbox`` (converted from XYXY_ABS to XYWH_ABS), ``score`` and, when
        ``pred_masks`` is present, an RLE ``segmentation``. An empty list when
        there are no instances.
    """
    if len(instances) == 0:
        return []

    xyxy = instances.pred_boxes.to("cpu").tensor.numpy()
    bbox_list = BoxMode.convert(xyxy, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS).tolist()
    score_list = instances.scores.tolist()
    class_list = instances.pred_classes.tolist()

    records = [
        {
            "image_id": img_id,
            "category_id": category,
            "bbox": bbox,
            "score": score,
        }
        for category, bbox, score in zip(class_list, bbox_list, score_list)
    ]

    if instances.has("pred_masks"):
        # Masks are stored as RLE because full masks for a whole dataset
        # would take too much memory.
        encoded = []
        for mask in instances.pred_masks.to('cpu'):
            rle = mask_util.encode(np.array(mask[:, :, None], order="F", dtype="uint8"))[0]
            # pycocotools returns "counts" as bytes, which json cannot
            # serialize; utf-8 decoding matches pycocotools' own handling.
            rle["counts"] = rle["counts"].decode("utf-8")
            encoded.append(rle)
        for record, rle in zip(records, encoded):
            record["segmentation"] = rle

    return records
def list_files(in_path):
    """Collect image files under ``in_path``, rasterising any PDFs found.

    Walks ``in_path`` recursively:

    - Files whose extension is ``.jpg``, ``.jpeg``, ``.gif``, ``.png`` or
      ``.pgm`` (case-insensitive) are returned as they are.
    - Each ``.pdf`` is converted with ``pdf2image.convert_from_path``. Every
      page is saved as ``<pdf name>_page<N>.jpg`` (N starting at 1) in the
      PDF's own directory, and those JPEG paths are returned as well.

    Args:
        in_path (str): root directory to scan.

    Returns:
        list[str]: image paths in ``os.walk`` order, without duplicates. A page
        JPEG left over from an earlier run is listed only once.
    """
    image_exts = ('.jpg', '.jpeg', '.gif', '.png', '.pgm')
    img_files = []
    for dirpath, _dirnames, filenames in os.walk(in_path):
        for file in filenames:
            filename, ext = os.path.splitext(file)
            ext = ext.lower()
            if ext == '.pdf':
                # convert_from_path needs the PDF's own path (not the walked
                # directory) and returns one PIL image per page.
                pages = convert_from_path(os.path.join(dirpath, file))
                for page_no, page in enumerate(pages, start=1):
                    page_path = os.path.join(dirpath, f"{filename}_page{page_no}.jpg")
                    page.save(page_path, 'JPEG')
                    img_files.append(page_path)
            elif ext in image_exts:
                img_files.append(os.path.join(dirpath, file))
    # dict.fromkeys drops duplicates while keeping first-seen order.
    return list(dict.fromkeys(img_files))
def main(args):
    """Run the configured detector over every image under ``--imagePath``.

    Each image is passed to ``CustomPredictor``, its predictions are converted
    with ``custom_instance_to_coco``, and all records are written as one JSON
    list to ``<cfg.OUTPUT_DIR>/coco_instances_results.json``. The file is only
    opened once every image has been processed, so a failure part-way through
    does not leave a truncated result file behind.

    Args:
        args: parsed command-line namespace. ``config_file`` and ``opts`` are
            read by ``setup``; ``imagePath`` is the directory to scan.
    """
    cfg = setup(args)
    if args.imagePath is None:
        print("Please input the path to the input image")
        return

    # Scan once: list_files also rasterises PDFs, so a second call would
    # redo that work.
    img_list = list_files(args.imagePath)
    predict = CustomPredictor(cfg)
    final_result = []
    for idx, image in enumerate(img_list):
        img_id = os.path.basename(image)
        print("Test image {:d}/{:d}: {:s}".format(idx + 1, len(img_list), img_id))
        output = predict(image)
        # extend() keeps accumulation linear instead of rebuilding the list.
        final_result.extend(custom_instance_to_coco(output, img_id))
        # Release this image's predictions before processing the next one.
        del output

    with io.open(os.path.join(cfg.OUTPUT_DIR, 'coco_instances_results.json'), 'w') as db_file:
        db_file.write(json.dumps(final_result))
if __name__ == "__main__":
    parser = default_argument_parser()
    parser.add_argument("--debug", action="store_true", help="enable debug mode")
    parser.add_argument("--imagePath", default=None, help="test on the new data")
    args = parser.parse_args()
    print("Command Line Args:", args)

    if args.debug:
        # Start a debugpy server and block until a client attaches.
        # NOTE(review): this listens on all interfaces (0.0.0.0:9310), so
        # anyone who can reach the host can attach a debugger; prefer
        # 127.0.0.1 unless remote attach is actually required.
        import debugpy

        print("Enabling attach starts.")
        debugpy.listen(address=('0.0.0.0', 9310))
        debugpy.wait_for_client()
        print("Enabling attach ends.")

    launch_kwargs = {
        "num_machines": args.num_machines,
        "machine_rank": args.machine_rank,
        "dist_url": args.dist_url,
        "args": (args,),
    }
    launch(main, args.num_gpus, **launch_kwargs)