Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleaned-up coco evaluation code #4453

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 15 additions & 172 deletions references/detection/coco_eval.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,17 @@
import json
import copy
import io
from contextlib import redirect_stdout

import numpy as np
import copy
import pycocotools.mask as mask_util
import torch
import torch._six

from pycocotools.cocoeval import COCOeval
from pycocotools.coco import COCO
import pycocotools.mask as mask_util

from collections import defaultdict

import utils


class CocoEvaluator(object):
class CocoEvaluator:
def __init__(self, coco_gt, iou_types):
assert isinstance(iou_types, (list, tuple))
coco_gt = copy.deepcopy(coco_gt)
Expand All @@ -34,7 +31,8 @@ def update(self, predictions):

for iou_type in self.iou_types:
results = self.prepare(predictions, iou_type)
coco_dt = loadRes(self.coco_gt, results) if results else COCO()
with redirect_stdout(io.StringIO()):
coco_dt = COCO.loadRes(self.coco_gt, results) if results else COCO()
coco_eval = self.coco_eval[iou_type]

coco_eval.cocoDt = coco_dt
Expand All @@ -54,18 +52,17 @@ def accumulate(self):

def summarize(self):
for iou_type, coco_eval in self.coco_eval.items():
print("IoU metric: {}".format(iou_type))
print(f"IoU metric: {iou_type}")
coco_eval.summarize()

def prepare(self, predictions, iou_type):
if iou_type == "bbox":
return self.prepare_for_coco_detection(predictions)
elif iou_type == "segm":
if iou_type == "segm":
return self.prepare_for_coco_segmentation(predictions)
elif iou_type == "keypoints":
if iou_type == "keypoints":
return self.prepare_for_coco_keypoint(predictions)
else:
raise ValueError("Unknown iou type {}".format(iou_type))
raise ValueError(f"Unknown iou type {iou_type}")

def prepare_for_coco_detection(self, predictions):
coco_results = []
Expand Down Expand Up @@ -190,161 +187,7 @@ def create_common_coco_eval(coco_eval, img_ids, eval_imgs):
coco_eval._paramsEval = copy.deepcopy(coco_eval.params)


#################################################################
# From pycocotools, just removed the prints and fixed
# a Python3 bug about unicode not defined
#################################################################

# Ideally, pycocotools wouldn't have hard-coded prints
# so that we could avoid copy-pasting those two functions

def createIndex(self):
# create index
# print('creating index...')
anns, cats, imgs = {}, {}, {}
imgToAnns, catToImgs = defaultdict(list), defaultdict(list)
if 'annotations' in self.dataset:
for ann in self.dataset['annotations']:
imgToAnns[ann['image_id']].append(ann)
anns[ann['id']] = ann

if 'images' in self.dataset:
for img in self.dataset['images']:
imgs[img['id']] = img

if 'categories' in self.dataset:
for cat in self.dataset['categories']:
cats[cat['id']] = cat

if 'annotations' in self.dataset and 'categories' in self.dataset:
for ann in self.dataset['annotations']:
catToImgs[ann['category_id']].append(ann['image_id'])

# print('index created!')

# create class members
self.anns = anns
self.imgToAnns = imgToAnns
self.catToImgs = catToImgs
self.imgs = imgs
self.cats = cats


maskUtils = mask_util


def loadRes(self, resFile):
"""
Load result file and return a result api object.
Args:
self (obj): coco object with ground truth annotations
resFile (str): file name of result file
Returns:
res (obj): result api object
"""
res = COCO()
res.dataset['images'] = [img for img in self.dataset['images']]

# print('Loading and preparing results...')
# tic = time.time()
if isinstance(resFile, torch._six.string_classes):
anns = json.load(open(resFile))
elif type(resFile) == np.ndarray:
anns = self.loadNumpyAnnotations(resFile)
else:
anns = resFile
assert type(anns) == list, 'results in not an array of objects'
annsImgIds = [ann['image_id'] for ann in anns]
assert set(annsImgIds) == (set(annsImgIds) & set(self.getImgIds())), \
'Results do not correspond to current coco set'
if 'caption' in anns[0]:
imgIds = set([img['id'] for img in res.dataset['images']]) & set([ann['image_id'] for ann in anns])
res.dataset['images'] = [img for img in res.dataset['images'] if img['id'] in imgIds]
for id, ann in enumerate(anns):
ann['id'] = id + 1
elif 'bbox' in anns[0] and not anns[0]['bbox'] == []:
res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
for id, ann in enumerate(anns):
bb = ann['bbox']
x1, x2, y1, y2 = [bb[0], bb[0] + bb[2], bb[1], bb[1] + bb[3]]
if 'segmentation' not in ann:
ann['segmentation'] = [[x1, y1, x1, y2, x2, y2, x2, y1]]
ann['area'] = bb[2] * bb[3]
ann['id'] = id + 1
ann['iscrowd'] = 0
elif 'segmentation' in anns[0]:
res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
for id, ann in enumerate(anns):
# now only support compressed RLE format as segmentation results
ann['area'] = maskUtils.area(ann['segmentation'])
if 'bbox' not in ann:
ann['bbox'] = maskUtils.toBbox(ann['segmentation'])
ann['id'] = id + 1
ann['iscrowd'] = 0
elif 'keypoints' in anns[0]:
res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
for id, ann in enumerate(anns):
s = ann['keypoints']
x = s[0::3]
y = s[1::3]
x1, x2, y1, y2 = np.min(x), np.max(x), np.min(y), np.max(y)
ann['area'] = (x2 - x1) * (y2 - y1)
ann['id'] = id + 1
ann['bbox'] = [x1, y1, x2 - x1, y2 - y1]
# print('DONE (t={:0.2f}s)'.format(time.time()- tic))

res.dataset['annotations'] = anns
createIndex(res)
return res


def evaluate(self):
'''
Run per image evaluation on given images and store results (a list of dict) in self.evalImgs
:return: None
'''
# tic = time.time()
# print('Running per image evaluation...')
p = self.params
# add backward compatibility if useSegm is specified in params
if p.useSegm is not None:
p.iouType = 'segm' if p.useSegm == 1 else 'bbox'
print('useSegm (deprecated) is not None. Running {} evaluation'.format(p.iouType))
# print('Evaluate annotation type *{}*'.format(p.iouType))
p.imgIds = list(np.unique(p.imgIds))
if p.useCats:
p.catIds = list(np.unique(p.catIds))
p.maxDets = sorted(p.maxDets)
self.params = p

self._prepare()
# loop through images, area range, max detection number
catIds = p.catIds if p.useCats else [-1]

if p.iouType == 'segm' or p.iouType == 'bbox':
computeIoU = self.computeIoU
elif p.iouType == 'keypoints':
computeIoU = self.computeOks
self.ious = {
(imgId, catId): computeIoU(imgId, catId)
for imgId in p.imgIds
for catId in catIds}

evaluateImg = self.evaluateImg
maxDet = p.maxDets[-1]
evalImgs = [
evaluateImg(imgId, catId, areaRng, maxDet)
for catId in catIds
for areaRng in p.areaRng
for imgId in p.imgIds
]
# this is NOT in the pycocotools code, but could be done outside
evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds))
self._paramsEval = copy.deepcopy(self.params)
# toc = time.time()
# print('DONE (t={:0.2f}s).'.format(toc-tic))
return p.imgIds, evalImgs

#################################################################
# end of straight copy from pycocotools, just removing the prints
#################################################################
def evaluate(imgs):
with redirect_stdout(io.StringIO()):
imgs.evaluate()
return imgs.params.imgIds, np.asarray(imgs.evalImgs).reshape(-1, len(imgs.params.areaRng), len(imgs.params.imgIds))