Skip to content

Commit

Permalink
COCO eval supports custom datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
leondgarse committed Oct 22, 2022
1 parent 54f5224 commit 56c222a
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 22 deletions.
21 changes: 10 additions & 11 deletions coco_train_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def parse_arguments(argv):
ds_group.add_argument("--rescale_mode", type=str, default="torch", help="Rescale mode, one of [tf, torch, raw, raw01]")
ds_group.add_argument("--resize_method", type=str, default="bicubic", help="Resize method from tf.image.resize, like [bilinear, bicubic]")
ds_group.add_argument("--disable_antialias", action="store_true", help="Set use antialias=False for tf.image.resize")
ds_group.add_argument("--max_labels_per_image", type=int, default=100, help="Max number of ground truth labels used in a single image")

args = parser.parse_known_args(argv)[0]

Expand Down Expand Up @@ -193,6 +194,7 @@ def run_training_by_args(args):
data_name=args.data_name,
input_shape=input_shape,
batch_size=batch_size,
max_labels_per_image=args.max_labels_per_image,
anchors_mode=args.anchors_mode,
anchor_pyramid_levels=args.anchor_pyramid_levels,
anchor_scale=args.anchor_scale,
Expand Down Expand Up @@ -235,17 +237,14 @@ def run_training_by_args(args):
print(">>>> basic_save_name =", args.basic_save_name)
# return None, None, None

if args.data_name == "coco/2017":
# Save line width...
kw = {"batch_size": batch_size, "rescale_mode": args.rescale_mode, "resize_method": args.resize_method, "resize_antialias": resize_antialias}
kw.update({"anchor_scale": args.anchor_scale, "anchors_mode": args.anchors_mode, "model_basic_save_name": args.basic_save_name})
start_epoch, frequency = epochs * 2 // 3, 1 # coco eval starts from 2/3 epochs
coco_ap_eval = eval_func.COCOEvalCallback("coco/2017", start_epoch=start_epoch, frequency=frequency, **kw)
init_callbacks = [coco_ap_eval]
test_dataset = None # COCO eval using coco_ap_eval callback, set `validation_data` for `model.fit` to None
print(">>>> COCO AP eval start_epoch: {}, frequency: {}".format(start_epoch, frequency))
else:
init_callbacks = []
# Save line width...
kw = {"batch_size": batch_size, "rescale_mode": args.rescale_mode, "resize_method": args.resize_method, "resize_antialias": resize_antialias}
kw.update({"anchor_scale": args.anchor_scale, "anchors_mode": args.anchors_mode, "model_basic_save_name": args.basic_save_name})
start_epoch, frequency = epochs * 2 // 3, 1 # coco eval starts from 2/3 epochs
coco_ap_eval = eval_func.COCOEvalCallback(args.data_name, start_epoch=start_epoch, frequency=frequency, **kw)
init_callbacks = [coco_ap_eval]
test_dataset = None # COCO eval using coco_ap_eval callback, set `validation_data` for `model.fit` to None
print(">>>> COCO AP eval start_epoch: {}, frequency: {}".format(start_epoch, frequency))
latest_save, hist = train(
model, epochs, train_dataset, test_dataset, args.initial_epoch, lr_scheduler, args.basic_save_name, init_callbacks, logs=args.tensorboard_logs
)
Expand Down
72 changes: 61 additions & 11 deletions keras_cv_attention_models/coco/eval_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,14 +146,15 @@ def scale_bboxes_back_single(bboxes, image_shape, scale, pad_top, pad_left, targ
def image_process(image, target_shape, mean, std, resize_method="bilinear", resize_antialias=False, use_bgr_input=False, letterbox_pad=-1):
    """Normalize and resize one image for evaluation, keeping track of its original size.

    Args:
      image: image tensor, or a value whose shape has rank < 2, which is treated as an image path
          and loaded with `data.tf_imread`.
      target_shape: forwarded to `data.aspect_aware_resize_and_crop_image` as the resize target.
      mean, std: normalization applied as `(image - mean) / std` before resizing.
      resize_method: forwarded as `method` to the resize helper.
      resize_antialias: forwarded as `antialias` to the resize helper.
      use_bgr_input: if True, the last axis of the resized image is reversed.
      letterbox_pad: forwarded unchanged to the resize helper.

    Returns:
      `(image, scale, pad_top, pad_left, original_image_shape)`. The first four come from the
      resize helper; `original_image_shape` is the first two entries of `tf.shape(image)` taken
      before any cast or resize (presumably `[height, width]` -- confirm against `data.tf_imread`).
    """
    if len(image.shape) < 2:
        image = data.tf_imread(image)  # it's image path
    # Must be captured before resizing, it is needed later for scaling bboxes back.
    original_image_shape = tf.shape(image)[:2]
    image = tf.cast(image, "float32")
    image = (image - mean) / std  # automl behavior: rescale -> resize
    image, scale, pad_top, pad_left = data.aspect_aware_resize_and_crop_image(
        image, target_shape, letterbox_pad=letterbox_pad, method=resize_method, antialias=resize_antialias
    )
    if use_bgr_input:
        image = image[:, :, ::-1]
    # Single return: the dataset mapping unpacks all five values and appends the image id after them.
    return image, scale, pad_top, pad_left, original_image_shape


def init_eval_dataset(
Expand All @@ -172,15 +173,18 @@ def init_eval_dataset(
mean, std = data.init_mean_std_by_rescale_mode(rescale_mode)
__image_process__ = lambda image: image_process(image, input_shape, mean, std, resize_method, resize_antialias, use_bgr_input, letterbox_pad)
# ds: [resized_image, scale, pad_top, pad_left, original_image_shape, image_id]
ds = ds.map(lambda datapoint: (*__image_process__(datapoint["image"]), tf.shape(datapoint["image"])[:2], datapoint["image/id"]))
ds = ds.map(lambda datapoint: (*__image_process__(datapoint["image"]), datapoint.get("image/id", datapoint["image"])))
ds = ds.batch(batch_size)
return ds


def model_detection_and_decode(model, eval_dataset, pred_decoder, nms_kwargs={}):
def model_detection_and_decode(model, eval_dataset, pred_decoder, nms_kwargs={}, is_coco=True, image_id_map=None):
target_shape = (eval_dataset.element_spec[0].shape[1], eval_dataset.element_spec[0].shape[2])
num_classes = model.output_shape[-1] - 4
to_91_labels = (lambda label: label + 1) if num_classes >= 90 else (lambda label: data.COCO_80_to_90_LABEL_DICT[label] + 1)
if is_coco:
to_91_labels = (lambda label: label + 1) if num_classes >= 90 else (lambda label: data.COCO_80_to_90_LABEL_DICT[label] + 1)
else:
to_91_labels = lambda label: label
# Format: [image_id, x, y, width, height, score, class]
to_coco_eval_single = lambda image_id, bbox, label, score: [image_id, *bbox.tolist(), score, to_91_labels(label)]

Expand All @@ -195,20 +199,28 @@ def model_detection_and_decode(model, eval_dataset, pred_decoder, nms_kwargs={})
for rr, image_shape, scale, pad_top, pad_left, image_id in zip(decoded_preds, original_image_shapes, scales, pad_tops, pad_lefts, image_ids):
bboxes, labels, scores = rr
image_id = image_id.numpy()
if image_id_map is not None:
image_id = image_id_map[image_id.decode() if isinstance(image_id, bytes) else image_id]
bboxes = scale_bboxes_back_single(bboxes, image_shape, scale, pad_top, pad_left, target_shape).numpy()
results.extend([to_coco_eval_single(image_id, bb, cc, ss) for bb, cc, ss in zip(bboxes, labels, scores)]) # Loop on prediction results
return tf.convert_to_tensor(results).numpy()


def coco_evaluation(detection_results, annotation_file=None):
def coco_evaluation(detection_results, annotations=None):
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

if annotation_file is None:
if annotations is None:
url = "https://github.com/leondgarse/keras_cv_attention_models/releases/download/efficientdet/coco_annotations_instances_val2017.json"
annotation_file = tf.keras.utils.get_file(origin=url)
coco_gt = COCO(annotation_file)
image_ids = [ii["image_id"] for ii in detection_results] if isinstance(detection_results[0], dict) else detection_results[:, 0]
annotations = tf.keras.utils.get_file(origin=url)

if isinstance(annotations, dict): # json already loaded as dict
coco_gt = COCO()
coco_gt.dataset = annotations
coco_gt.createIndex()
else:
coco_gt = COCO(annotations)
image_ids = [ii["image_id"] for ii in detection_results] if isinstance(detection_results[0], dict) else [ii[0] for ii in detection_results]
image_ids = list(set(image_ids))
print("len(image_ids) =", len(image_ids))
coco_dt = coco_gt.loadRes(detection_results)
Expand All @@ -223,12 +235,43 @@ def coco_evaluation(detection_results, annotation_file=None):
def to_coco_json(detection_results, save_path, indent=2):
    """Write detection results to `save_path` as a JSON list of COCO-style result dicts.

    Args:
      detection_results: iterable of rows, each indexable as
          `[image_id, bbox_0, bbox_1, bbox_2, bbox_3, score, category_id]`.
          Rows may be plain lists/tuples or array rows.
      save_path: output file path, opened in text write mode.
      indent: forwarded to `json.dump`.
    """
    import json

    def __to_coco_json__(xx):
        # Convert element-wise with float() / int() so both plain lists and array rows work,
        # and so no numpy scalar reaches the JSON encoder.
        return {
            "image_id": int(xx[0]),
            "bbox": [float(ii) for ii in xx[1:5]],
            "score": float(xx[5]),
            "category_id": int(xx[6]),
        }

    aa = [__to_coco_json__(ii) for ii in detection_results]
    with open(save_path, "w") as ff:
        json.dump(aa, ff, indent=indent)


def to_coco_annotation(json_path):
    """Convert a custom dataset json file into a COCO-style annotation dict.

    Uses the "validation" split of the json, falling back to "test", then to an empty list.

    Args:
      json_path: path of the dataset json. Each split item is read as
          `{"image": <image path>, "objects": {"bbox": [...], "label": [...]}}`. An optional
          top-level "indices_2_labels" dict supplies category names keyed by label index.

    Returns:
      `(annotation_dict, image_id_map)`:
        - `annotation_dict` has keys "images", "annotations", "categories".
        - `image_id_map` maps each item's "image" value to its integer image id, which is the
          item's index within the split.
    """
    import json
    from PIL import Image

    with open(json_path, "r") as ff:
        aa = json.load(ff)

    # int conversion just in case key is str
    categories = {int(kk): vv for kk, vv in aa["indices_2_labels"].items()} if "indices_2_labels" in aa else {}
    annotations, images, image_id_map = [], [], {}
    for image_id, ii in enumerate(aa.get("validation", aa.get("test", []))):
        # Only the size is needed for decoding bboxes. `Image.open` is lazy and keeps the file
        # open, so close it explicitly instead of leaking one handle per image.
        with Image.open(ii["image"]) as img:
            width, height = img.size

        for bb, label in zip(ii["objects"]["bbox"], ii["objects"]["label"]):
            # bb [top, left, bottom, right] in [0, 1] -> [left, top, bbox_width, bbox_height] with actual coordinates
            top = bb[0] * height
            left = bb[1] * width
            bbox_height = bb[2] * height - top
            bbox_width = bb[3] * width - left
            bb = [left, top, bbox_width, bbox_height]
            area = bbox_width * bbox_height  # Actual area in COCO is the segmentation area, doesn't matter in detection mission

            label = int(label)
            annotations.append({"bbox": bb, "category_id": label, "image_id": image_id, "id": len(annotations), "iscrowd": 0, "area": area})
            if label not in categories:
                # No name provided for this label, fall back to a generated one.
                categories[label] = str(len(categories))
        images.append({"id": image_id, "file_name": ii["image"], "height": height, "width": width})
        image_id_map[ii["image"]] = image_id
    categories = [{"id": kk, "name": vv} for kk, vv in categories.items()]
    return {"images": images, "annotations": annotations, "categories": categories}, image_id_map


# Wrap COCO evaluation in a callback for use during training
class COCOEvalCallback(tf.keras.callbacks.Callback):
"""
Expand Down Expand Up @@ -287,6 +330,13 @@ def __init__(
"mode": nms_mode,
"topk": nms_topk,
}

self.is_coco = True if data_name.startswith("coco") and not data_name.endswith(".json") else False
if self.data_name.endswith(".json") and self.annotation_file is None:
self.annotation_file, self.image_id_map = to_coco_annotation(self.data_name)
else:
self.image_id_map = None

self.built = False

def build(self, input_shape, output_shape):
Expand Down Expand Up @@ -322,7 +372,7 @@ def on_epoch_end(self, epoch=0, logs=None):

# pred_decoder = self.model.decode_predictions
eval_dataset = self.eval_dataset.take(self.take_samples) if self.take_samples > 0 else self.eval_dataset
detection_results = model_detection_and_decode(self.model, eval_dataset, self.pred_decoder, self.nms_kwargs)
detection_results = model_detection_and_decode(self.model, eval_dataset, self.pred_decoder, self.nms_kwargs, self.is_coco, self.image_id_map)
try:
coco_eval = coco_evaluation(detection_results, self.annotation_file)
except:
Expand Down

0 comments on commit 56c222a

Please sign in to comment.