From 954fac7e2040368835aedd651ca9cea96d01e769 Mon Sep 17 00:00:00 2001 From: TelBotDev <77771760+TelBotDev@users.noreply.github.com> Date: Tue, 22 Jun 2021 21:40:49 +0800 Subject: [PATCH 1/5] Fixed: set default type for parser MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The value will be treated as `str` if there is no type. Then there will be error ``` Traceback (most recent call last):   File "demo/visualize_result.py", line 144, in     query_indices = visualizer.vis_rank_list(args.output, args.vis_label, args.num_vis, args.rank_sort, args.label_sort, args.max_rank)   File "./fastreid/utils/visualizer.py", line 158, in vis_rank_list     query_indices = query_indices[:num_vis] TypeError: slice indices must be integers or None or have an __index__ method ``` when we add some custom parameter, such as `--num-vis 5` --- demo/visualize_result.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/demo/visualize_result.py b/demo/visualize_result.py index 493b70202..6901addf6 100644 --- a/demo/visualize_result.py +++ b/demo/visualize_result.py @@ -72,6 +72,7 @@ def get_parser(): ) parser.add_argument( "--num-vis", + type=int, default=100, help="number of query images to be visualized", ) @@ -87,6 +88,7 @@ def get_parser(): ) parser.add_argument( "--max-rank", + type=int, default=10, help="maximum number of rank list to be visualized", ) From b495398f99ad823072d9f488077d8856f7ce3d17 Mon Sep 17 00:00:00 2001 From: TelBotDev <77771760+TelBotDev@users.noreply.github.com> Date: Wed, 23 Jun 2021 10:22:53 +0800 Subject: [PATCH 2/5] Update visualizer.py add actmap --- fastreid/utils/visualizer.py | 71 ++++++++++++++++-------------------- 1 file changed, 31 insertions(+), 40 deletions(-) diff --git a/fastreid/utils/visualizer.py b/fastreid/utils/visualizer.py index 5a06abd85..531287833 100644 --- a/fastreid/utils/visualizer.py +++ b/fastreid/utils/visualizer.py @@ -7,6 +7,8 @@ import os import pickle import random +import cv2 +import torch.nn.functional as F import matplotlib.pyplot as plt import numpy as np @@ -23,7 +25,7 @@ class Visualizer: def __init__(self, dataset): self.dataset = dataset - def get_model_output(self, all_ap, dist, q_pids, g_pids, q_camids, g_camids): + def get_model_output(self, all_ap, dist, q_pids, g_pids, q_camids, g_camids, acts=None): self.all_ap = all_ap self.dist = dist self.sim = 1 - dist @@ -36,6 +38,8 @@ def get_model_output(self, all_ap, dist, q_pids, g_pids, q_camids, g_camids): self.matches = (g_pids[self.indices] == q_pids[:, np.newaxis]).astype(np.int32) self.num_query = len(q_pids) + + if acts: self.acts = acts def get_matched_result(self, q_index): q_pid = self.q_pids[q_index] @@ -65,7 +69,19 @@ def save_rank_result(self, query_indices, output, max_rank=5, vis_label=False, l query_img = np.rollaxis(np.asarray(query_img.numpy(), dtype=np.uint8), 0, 3) plt.clf() ax = fig.add_subplot(1, max_rank + 1, 1) - ax.imshow(query_img) + + # ax.imshow(query_img) + # added: show acts + if actmap: + query_acts = self.acts[q_idx] + overlapped = query_img*0.3 + query_acts*0.7 + overlapped[overlapped > 255] = 255 + overlapped = overlapped.astype(np.uint8) + ax.imshow(overlapped) + # added: show acts + else: + ax.imshow(query_img) + ax.set_title('{:.4f}/cam{}'.format(self.all_ap[q_idx], cam_id)) ax.axis("off") for i in range(max_rank): @@ -89,27 +105,21 @@ def save_rank_result(self, query_indices, output, max_rank=5, vis_label=False, l ax.add_patch(plt.Rectangle(xy=(0, 0), width=gallery_img.shape[1] - 1, height=gallery_img.shape[0] - 1, edgecolor=(0, 0, 1), fill=False, linewidth=5)) - ax.imshow(gallery_img) + + # added: show acts + if actmap: + gallery_acts = self.acts[g_idx] + overlapped = gallery_img*0.3 + gallery_acts*0.7 + overlapped[overlapped > 255] = 255 + overlapped = overlapped.astype(np.uint8) + ax.imshow(overlapped) + # added: show acts + else: + ax.imshow(gallery_img) + ax.set_title(f'{self.sim[q_idx, sort_idx[i]]:.3f}/{label}/cam{cam_id}') ax.axis("off") - # if actmap: - # act_outputs = [] - # - # def hook_fns_forward(module, input, output): - # act_outputs.append(output.cpu()) - # - # all_imgs = np.stack(all_imgs, axis=0) # (b, 3, h, w) - # all_imgs = torch.from_numpy(all_imgs).float() - # # normalize - # all_imgs = all_imgs.sub_(self.mean).div_(self.std) - # sz = list(all_imgs.shape[-2:]) - # handle = m.base.register_forward_hook(hook_fns_forward) - # with torch.no_grad(): - # _ = m(all_imgs.cuda()) - # handle.remove() - # acts = self.get_actmap(act_outputs[0], sz) - # for i in range(top + 1): - # axes.flat[i].imshow(acts[i], alpha=0.3, cmap='jet') + if vis_label: label_indice = np.where(cmc == 1)[0] if label_sort == "ascending": label_indice = label_indice[::-1] @@ -257,22 +267,3 @@ def load_roc_info(path): # plt.xticks(np.arange(0.1, 1.0, 0.1)) # plt.title('positive and negative pair distribution') # return fig - - # def get_actmap(self, features, sz): - # """ - # :param features: (1, 2048, 16, 8) activation map - # :return: - # """ - # features = (features ** 2).sum(1) # (1, 16, 8) - # b, h, w = features.size() - # features = features.view(b, h * w) - # features = nn.functional.normalize(features, p=2, dim=1) - # acts = features.view(b, h, w) - # all_acts = [] - # for i in range(b): - # act = acts[i].numpy() - # act = cv2.resize(act, (sz[1], sz[0])) - # act = 255 * (act - act.max()) / (act.max() - act.min() + 1e-12) - # act = np.uint8(np.floor(act)) - # all_acts.append(act) - # return all_acts From f1bd1c56e31a94ef9e7c19bb42e2a98fffcfa70c Mon Sep 17 00:00:00 2001 From: TelBotDev <77771760+TelBotDev@users.noreply.github.com> Date: Wed, 23 Jun 2021 10:23:57 +0800 Subject: [PATCH 3/5] Update predictor.py add actmap --- demo/predictor.py | 44 +++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 43 insertions(+), 1 deletion(-) diff --git a/demo/predictor.py b/demo/predictor.py index c949fb397..ed7cd5375 100644 --- a/demo/predictor.py +++ b/demo/predictor.py @@ -9,6 +9,9 @@ from collections import deque import cv2 +import numpy as np +import torch.nn.functional as F + import torch import torch.multiprocessing as mp @@ -57,6 +60,31 @@ def run_on_image(self, original_image): predictions = self.predictor(image) return predictions + + def get_actmap(self, features, sz): + """ + :param features: (1, 2048, 16, 8) activation map + :return: + """ + features = (features ** 2).sum(1) # (1, 16, 8) + b, h, w = features.size() + features = features.view(b, h * w) + features = F.normalize(features, p=2, dim=1) + acts = features.view(b, h, w) + all_acts = [] + for i in range(b): + act = acts[i].numpy() + act = cv2.resize(act, (sz[1], sz[0])) + # act = 255 * (act - act.max()) / (act.max() - act.min() + 1e-12) + act = 255 * (act - act.min()) / (act.max() - act.min() + 1e-12) + + act = np.uint8(np.floor(act)) + act = cv2.applyColorMap(act, cv2.COLORMAP_JET) + + all_acts.append(act) + return all_acts + + def run_on_loader(self, data_loader): if self.parallel: buffer_size = self.predictor.default_buffer_size @@ -78,8 +106,22 @@ def run_on_loader(self, data_loader): yield predictions, batch["targets"].cpu().numpy(), batch["camids"].cpu().numpy() else: for batch in data_loader: + # add hook here to get features: start + act_outputs = [] + def hook_fns_forward(module, input, output): + act_outputs.append(output.cpu()) + handle = self.predictor.model.backbone.register_forward_hook(hook_fns_forward) + # add hook here to get features: end + predictions = self.predictor(batch["images"]) - yield predictions, batch["targets"].cpu().numpy(), batch["camids"].cpu().numpy() + + # add hook here to get features: start + handle.remove() + sz = list(batch["images"].shape[-2:]) + acts = self.get_actmap(act_outputs[0], sz) + # add hook here to get features: end + + yield predictions, batch["targets"].cpu().numpy(), batch["camids"].cpu().numpy(), acts class AsyncPredictor: From aadfa92943ea250b92e2f3d85a1c50413dd38eb2 Mon Sep 17 00:00:00 2001 From: TelBotDev <77771760+TelBotDev@users.noreply.github.com> Date: Wed, 23 Jun 2021 10:24:17 +0800 Subject: [PATCH 4/5] Update visualize_result.py add actmap --- demo/visualize_result.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/demo/visualize_result.py b/demo/visualize_result.py index 6901addf6..993c13637 100644 --- a/demo/visualize_result.py +++ b/demo/visualize_result.py @@ -55,6 +55,11 @@ def get_parser(): action='store_true', help='if use multiprocess for feature extraction.' ) + parser.add_argument( + '--actmap', + action='store_true', + help='if use activation map to overlap the image.' + ) parser.add_argument( "--dataset-name", help="a test dataset name for visualizing ranking list." @@ -111,10 +116,12 @@ def get_parser(): feats = [] pids = [] camids = [] - for (feat, pid, camid) in tqdm.tqdm(demo.run_on_loader(test_loader), total=len(test_loader)): + acts_list = [] + for (feat, pid, camid, acts) in tqdm.tqdm(demo.run_on_loader(test_loader), total=len(test_loader)): feats.append(feat) pids.extend(pid) camids.extend(camid) + acts_list.extend(acts) feats = torch.cat(feats, dim=0) q_feat = feats[:num_query] @@ -133,7 +140,7 @@ def get_parser(): logger.info("Finish computing APs for all query images!") visualizer = Visualizer(test_loader.dataset) - visualizer.get_model_output(all_ap, distmat, q_pids, g_pids, q_camids, g_camids) + visualizer.get_model_output(all_ap, distmat, q_pids, g_pids, q_camids, g_camids, acts_list) logger.info("Start saving ROC curve ...") fpr, tpr, pos, neg = visualizer.vis_roc_curve(args.output) @@ -142,5 +149,5 @@ def get_parser(): logger.info("Saving rank list result ...") query_indices = visualizer.vis_rank_list(args.output, args.vis_label, args.num_vis, - args.rank_sort, args.label_sort, args.max_rank) + args.rank_sort, args.label_sort, args.max_rank, args.actmap) logger.info("Finish saving rank list results!") From c7ced7a5d1571a8d99b28a2ff3fa9e6f4e2ec8e6 Mon Sep 17 00:00:00 2001 From: TelBotDev <77771760+TelBotDev@users.noreply.github.com> Date: Wed, 23 Jun 2021 10:35:57 +0800 Subject: [PATCH 5/5] add actmap add actmap --- demo/README.md | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/demo/README.md b/demo/README.md index 572e51f2e..66bad69ea 100644 --- a/demo/README.md +++ b/demo/README.md @@ -7,4 +7,12 @@ You can run this command to get cosine similarites between different images ```bash cd demo/ sh run_demo.sh -``` \ No newline at end of file +``` + +What is more, you can use this command to make thing more interesting +```bash +export CUDA_VISIBLE_DEVICES=0 +python3 demo/visualize_result.py --config-file ./configs/VeRi/sbs_R50-ibn.yml --actmap --dataset-name 'VeRi' --output logs/veri/sbs_R50-ibn/eval --opts MODEL.WEIGHTS logs/veri/sbs_R50-ibn/model_best.pth +``` +![4](https://user-images.githubusercontent.com/77771760/123026335-90dd8780-d40e-11eb-8a8d-1683dc19a05a.jpg) +where `--actmap` is used to add activation map upon the original image.