From 420ff3a3aa56d424c36594be20824af94f46d5da Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Wed, 28 Apr 2021 16:05:14 +0200 Subject: [PATCH] New Colors() class (#2963) --- detect.py | 9 +++------ models/common.py | 5 ++--- utils/plots.py | 24 ++++++++++++++++-------- 3 files changed, 21 insertions(+), 17 deletions(-) diff --git a/detect.py b/detect.py index f5e53d991504..ba42f349dbaf 100644 --- a/detect.py +++ b/detect.py @@ -11,7 +11,7 @@ from utils.datasets import LoadStreams, LoadImages from utils.general import check_img_size, check_requirements, check_imshow, non_max_suppression, apply_classifier, \ scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path, save_one_box -from utils.plots import plot_one_box +from utils.plots import colors, plot_one_box from utils.torch_utils import select_device, load_classifier, time_synchronized @@ -34,6 +34,7 @@ def detect(opt): model = attempt_load(weights, map_location=device) # load FP32 model stride = int(model.stride.max()) # model stride imgsz = check_img_size(imgsz, s=stride) # check img_size + names = model.module.names if hasattr(model, 'module') else model.names # get class names if half: model.half() # to FP16 @@ -52,10 +53,6 @@ def detect(opt): else: dataset = LoadImages(source, img_size=imgsz, stride=stride) - # Get names and colors - names = model.module.names if hasattr(model, 'module') else model.names - colors = [[random.randint(0, 255) for _ in range(3)] for _ in names] - # Run inference if device.type != 'cpu': model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters()))) # run once @@ -112,7 +109,7 @@ def detect(opt): c = int(cls) # integer class label = None if opt.hide_labels else (names[c] if opt.hide_conf else f'{names[c]} {conf:.2f}') - plot_one_box(xyxy, im0, label=label, color=colors[c], line_thickness=opt.line_thickness) + plot_one_box(xyxy, im0, label=label, color=colors(c, True), line_thickness=opt.line_thickness) if opt.save_crop: save_one_box(xyxy, im0s, file=save_dir / 'crops' / names[c] / f'{p.stem}.jpg', BGR=True) diff --git a/models/common.py b/models/common.py index a28621904b0e..9764d4c3a6c0 100644 --- a/models/common.py +++ b/models/common.py @@ -14,7 +14,7 @@ from utils.datasets import letterbox from utils.general import non_max_suppression, make_divisible, scale_coords, increment_path, xyxy2xywh, save_one_box -from utils.plots import color_list, plot_one_box +from utils.plots import colors, plot_one_box from utils.torch_utils import time_synchronized @@ -312,7 +312,6 @@ def __init__(self, imgs, pred, files, times=None, names=None, shape=None): self.s = shape # inference BCHW shape def display(self, pprint=False, show=False, save=False, crop=False, render=False, save_dir=Path('')): - colors = color_list() for i, (im, pred) in enumerate(zip(self.imgs, self.pred)): str = f'image {i + 1}/{len(self.pred)}: {im.shape[0]}x{im.shape[1]} ' if pred is not None: @@ -325,7 +324,7 @@ def display(self, pprint=False, show=False, save=False, crop=False, render=False if crop: save_one_box(box, im, file=save_dir / 'crops' / self.names[int(cls)] / self.files[i]) else: # all others - plot_one_box(box, im, label=label, color=colors[int(cls) % 10]) + plot_one_box(box, im, label=label, color=colors(cls)) im = Image.fromarray(im.astype(np.uint8)) if isinstance(im, np.ndarray) else im # from np if pprint: diff --git a/utils/plots.py b/utils/plots.py index f24513c6998d..ab6448aa96eb 100644 --- a/utils/plots.py +++ b/utils/plots.py @@ -26,12 +26,22 @@ matplotlib.use('Agg') # for writing to files only -def color_list(): - # Return first 10 plt colors as (r,g,b) https://stackoverflow.com/questions/51350872/python-from-color-name-to-rgb - def hex2rgb(h): +class Colors: + # Ultralytics color palette https://ultralytics.com/ + def __init__(self): + self.palette = [self.hex2rgb(c) for c in matplotlib.colors.TABLEAU_COLORS.values()] + self.n = len(self.palette) + + def __call__(self, i, bgr=False): + c = self.palette[int(i) % self.n] + return (c[2], c[1], c[0]) if bgr else c + + @staticmethod + def hex2rgb(h): # rgb order (PIL) return tuple(int(h[1 + i:1 + i + 2], 16) for i in (0, 2, 4)) - return [hex2rgb(h) for h in matplotlib.colors.TABLEAU_COLORS.values()] # or BASE_ (8), CSS4_ (148), XKCD_ (949) + +colors = Colors() # create instance for 'from utils.plots import colors' def hist2d(x, y, n=100): @@ -137,7 +147,6 @@ def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max h = math.ceil(scale_factor * h) w = math.ceil(scale_factor * w) - colors = color_list() # list of colors mosaic = np.full((int(ns * h), int(ns * w), 3), 255, dtype=np.uint8) # init for i, img in enumerate(images): if i == max_subplots: # if last batch has fewer images than we expect @@ -168,7 +177,7 @@ def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max boxes[[1, 3]] += block_y for j, box in enumerate(boxes.T): cls = int(classes[j]) - color = colors[cls % len(colors)] + color = colors(cls) cls = names[cls] if names else cls if labels or conf[j] > 0.25: # 0.25 conf thresh label = '%s' % cls if labels else '%s %.1f' % (cls, conf[j]) @@ -276,7 +285,6 @@ def plot_labels(labels, names=(), save_dir=Path(''), loggers=None): print('Plotting labels... ') c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes nc = int(c.max() + 1) # number of classes - colors = color_list() x = pd.DataFrame(b.transpose(), columns=['x', 'y', 'width', 'height']) # seaborn correlogram @@ -302,7 +310,7 @@ def plot_labels(labels, names=(), save_dir=Path(''), loggers=None): labels[:, 1:] = xywh2xyxy(labels[:, 1:]) * 2000 img = Image.fromarray(np.ones((2000, 2000, 3), dtype=np.uint8) * 255) for cls, *box in labels[:1000]: - ImageDraw.Draw(img).rectangle(box, width=1, outline=colors[int(cls) % 10]) # plot + ImageDraw.Draw(img).rectangle(box, width=1, outline=colors(cls)) # plot ax[1].imshow(img) ax[1].axis('off')