This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Sanjay's vision features cache script #4633

Merged
merged 18 commits on Sep 14, 2020
3 changes: 1 addition & 2 deletions allennlp/common/file_utils.py
@@ -31,7 +31,7 @@
from requests.adapters import HTTPAdapter
from requests.exceptions import ConnectionError
from requests.packages.urllib3.util.retry import Retry
import lmdb

from allennlp.common.tqdm import Tqdm

@@ -350,7 +350,6 @@ def _serialize(data):
buffer = pickle.dumps(data, protocol=-1)
return np.frombuffer(buffer, dtype=np.uint8)


class TensorCache:
def __init__(
self,
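For context while reading this diff: TensorCache (hence the lmdb import above) acts as a dict-like, on-disk store of tensors keyed by strings, with values passing through _serialize on the way in. A minimal usage sketch, assuming a placeholder path and that reading a key back mirrors the writes shown in the cache script at the end of this PR:

import torch
from allennlp.common.file_utils import TensorCache

cache = TensorCache("/tmp/feature_cache")    # placeholder path; opens (or creates) the LMDB-backed store
cache["image1.jpg"] = torch.rand(36, 2048)   # the tensor is serialized via _serialize and written to LMDB
if "image1.jpg" in cache:                    # membership checks consult the store
    features = cache["image1.jpg"]           # assumed read path: deserializes back into a torch.Tensor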
1 change: 0 additions & 1 deletion allennlp/data/dataset_readers/__init__.py
@@ -18,4 +18,3 @@
from allennlp.data.dataset_readers.babi import BabiReader
from allennlp.data.dataset_readers.text_classification_json import TextClassificationJsonReader
from allennlp.data.dataset_readers.nlvr2 import Nlvr2Reader
from allennlp.data.dataset_readers.vqav2 import VQAv2Reader
2 changes: 1 addition & 1 deletion allennlp/data/dataset_readers/nlvr2.py
@@ -152,7 +152,7 @@ def text_to_instance(
with torch.no_grad():
images = images.to(self.cuda_device)
sizes = sizes.to(self.cuda_device)
featurized_images = self.image_featurizer(images)
featurized_images = self.image_featurizer(images, sizes)
detector_results = self.region_detector(images, sizes, featurized_images)
features = detector_results["features"]
coordinates = detector_results["coordinates"]
11 changes: 6 additions & 5 deletions allennlp/data/image_loader.py
@@ -2,7 +2,7 @@
from typing import Union, Sequence, Optional, Tuple

import torch
from torch import FloatTensor, IntTensor
from torch import FloatTensor, IntTensor, ByteTensor

from allennlp.common.registrable import Registrable

@@ -66,14 +66,15 @@ def __init__(
raise ValueError("Unknown type of `config`")

self.mapper = pipeline.mapper
self.model = pipeline.model

def load(self, filenames: ManyPaths) -> ImagesWithSize:
images = [{"file_name": str(f)} for f in filenames]
images = [self.mapper(i) for i in images]
processed_images = self.model.preprocess_image(images)

from detectron2.structures import ImageList
images = ImageList.from_tensors([image['image'] for image in images])

return (
processed_images.tensor,
torch.tensor(processed_images.image_sizes, dtype=torch.int32),
images.tensor.float() / 256,
torch.tensor(images.image_sizes, dtype=torch.int32)
)
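A hedged sketch of the loader's contract after this change, mirroring how the cache script at the end of this PR calls it; the file paths and shapes are illustrative assumptions:

from allennlp.data import DetectronImageLoader

loader = DetectronImageLoader()
images, sizes = loader(["photo1.jpg", "photo2.jpg"])  # placeholder paths
# images: FloatTensor, the Detectron2-preprocessed (padded) batch, e.g. shape (2, 3, H, W)
# sizes:  IntTensor of shape (2, 2) holding each image's true (height, width)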
33 changes: 15 additions & 18 deletions allennlp/modules/vision/grid_embedder.py
@@ -1,5 +1,5 @@
import torch
from torch import nn, FloatTensor
from torch import nn, FloatTensor, IntTensor, ByteTensor

from allennlp.common.registrable import Registrable

@@ -14,7 +14,7 @@ class GridEmbedder(nn.Module, Registrable):
of the patch. The size of the image might change during this operation.
"""

def forward(self, images: FloatTensor) -> FloatTensor:
def forward(self, images: FloatTensor, sizes: IntTensor) -> FloatTensor:
raise NotImplementedError()

def get_output_dim(self) -> int:
@@ -37,7 +37,7 @@ def get_stride(self) -> int:
class NullGridEmbedder(GridEmbedder):
"""A `GridEmbedder` that returns the input image as given."""

def forward(self, images: FloatTensor) -> FloatTensor:
def forward(self, images: FloatTensor, sizes: IntTensor) -> FloatTensor:
return images

def get_output_dim(self) -> int:
@@ -77,25 +77,22 @@ def __init__(
width_per_group=width_per_group,
depth=depth,
)
self.device = device
self.gpu = None
# set the gpu device here.
if torch.distributed.is_initialized():
self.gpu = torch.distributed.get_rank()

pipeline = detectron.get_pipeline_from_flat_parameters(flat_parameters, make_copy=False)
self.preprocessor = pipeline.model.preprocess_image
self.backbone = pipeline.model.backbone

def forward(self, images: FloatTensor) -> FloatTensor:

# move images into gpu if needed.
if self.device == 'cuda':
if self.gpu is not None:
images = images.cuda(self.gpu)
else:
images = images.cuda()

result = self.backbone(images)
def forward(self, images: FloatTensor, sizes: IntTensor) -> FloatTensor:
images = [
{
"image": (image[:, :height, :width] * 256).byte(),
"height": height,
"width": width
}
for image, (height, width) in zip(images, sizes)
]
images = self.preprocessor(images) # This returns tensors on the correct device.
result = self.backbone(images.tensor)
assert len(result) == 1
return next(iter(result.values()))

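A hedged sketch of the new GridEmbedder.forward(images, sizes) contract: the sizes tensor lets the backbone crop away batch padding before handing the images to Detectron2's own preprocessor, which re-pads, normalizes, and moves them to the right device. The shapes and values below are illustrative assumptions:

import torch
from allennlp.modules.vision import ResnetBackbone

backbone = ResnetBackbone()                                         # Detectron2 ResNet pipeline
images = torch.rand(2, 3, 800, 800)                                 # padded batch with values in [0, 1]
sizes = torch.tensor([[800, 640], [600, 800]], dtype=torch.int32)   # true (height, width) of each image
grid_features = backbone(images, sizes)                             # FloatTensor feature map from the backbone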
18 changes: 13 additions & 5 deletions allennlp/modules/vision/region_detector.py
@@ -109,10 +109,18 @@ def forward(
self, raw_images: FloatTensor, image_sizes: IntTensor, featurized_images: FloatTensor
) -> Dict[str, torch.Tensor]:
batch_size = len(image_sizes)
# RPN
from detectron2.structures import ImageList

image_list = ImageList(raw_images, [(image[0], image[1]) for image in image_sizes])
raw_images = [
{
"image": (image[:, :height, :width] * 256).byte(),
"height": height,
"width": width
}
for image, (height, width) in zip(raw_images, image_sizes)
]
image_list = self.model.preprocess_image(raw_images)

# RPN
assert len(self.model.proposal_generator.in_features) == 1
featurized_images_in_dict = {
self.model.proposal_generator.in_features[0]: featurized_images
@@ -127,7 +135,7 @@

predictions = self.model.roi_heads.box_predictor(pooled_features)

# class probablity
# class probability
cls_probs = F.softmax(predictions[0], dim=-1)
cls_probs = cls_probs[:, :-1] # background is last

@@ -138,7 +146,7 @@
batch_coordinates = []
batch_features = []
batch_probs = []
batch_num_detections = torch.zeros(batch_size, device=raw_images.device, dtype=torch.int16)
batch_num_detections = torch.zeros(batch_size, device=image_list.tensor.device, dtype=torch.int16)
feature_dim = pooled_features.size(-1)
num_classes = cls_probs.size(-1)

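Putting the pieces together, a hedged end-to-end sketch that mirrors process_batch in the script below; the path is a placeholder and all modules are assumed to live on the same device:

import torch
from allennlp.data import DetectronImageLoader
from allennlp.modules.vision import ResnetBackbone, FasterRcnnRegionDetector

loader = DetectronImageLoader()
backbone = ResnetBackbone()
detector = FasterRcnnRegionDetector()

images, sizes = loader(["photo1.jpg"])                # placeholder path
with torch.no_grad():
    grid_features = backbone(images, sizes)           # grid features from the ResNet backbone
    results = detector(images, sizes, grid_features)  # RPN proposals plus RoI-pooled features
features = results["features"]                        # pooled features per detected region
coordinates = results["coordinates"]                  # box coordinates per detected region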
71 changes: 71 additions & 0 deletions scripts/cache_vision_features.py
@@ -0,0 +1,71 @@
import os
import glob
import argparse
from os import PathLike
from typing import Union, List

from tqdm import tqdm

import torch
from torchvision.datasets.folder import IMG_EXTENSIONS

from allennlp.common.file_utils import TensorCache
from allennlp.data import DetectronImageLoader
from allennlp.modules.vision import ResnetBackbone, FasterRcnnRegionDetector


def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--image-dir", type=str, required=True)
parser.add_argument("--cache-dir", type=str, required=True)
parser.add_argument("--use-cuda", action="store_true", help="use GPU if one is available")
parser.add_argument("--batch-size", type=int, default=16)
args = parser.parse_args()
return args


if __name__ == "__main__":
args = parse_args()
os.makedirs(args.cache_dir, exist_ok=True)
features_cache = TensorCache(os.path.join(args.cache_dir, "features"))
coordinates_cache = TensorCache(os.path.join(args.cache_dir, "coordinates"))
image_paths = []
for extension in IMG_EXTENSIONS:
extension = extension.lstrip(".")  # Some versions of torchvision include the period; others do not.
image_paths += list(
glob.iglob(os.path.join(args.image_dir, "**", "*." + extension), recursive=True)
)

image_loader = DetectronImageLoader()
image_featurizer = ResnetBackbone()
region_detector = FasterRcnnRegionDetector()
if torch.cuda.is_available() and args.use_cuda:
image_featurizer.cuda()
region_detector.cuda()

def process_batch(batch: List[Union[str, PathLike]]):
batch_images, batch_shapes = image_loader(batch)
with torch.no_grad():
featurized_images = image_featurizer(batch_images, batch_shapes)
detector_results = region_detector(batch_images, batch_shapes, featurized_images)
features = detector_results["features"]
coordinates = detector_results["coordinates"]
for filename, image_features, image_coordinates in zip(batch, features, coordinates):
filename = os.path.basename(filename)
features_cache[filename] = image_features.cpu()
coordinates_cache[filename] = image_coordinates.cpu()

image_path_batch = []
for image_path in tqdm(image_paths, desc="Processing images"):
key = os.path.basename(image_path)
if key in features_cache and key in coordinates_cache:
continue
image_path_batch.append(image_path)

if len(image_path_batch) >= args.batch_size:
process_batch(image_path_batch)
image_path_batch.clear()
if len(image_path_batch) > 0:
process_batch(image_path_batch)
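A hypothetical invocation of the script; the directory paths and batch size are placeholders:

python scripts/cache_vision_features.py \
    --image-dir /data/images \
    --cache-dir /data/vision_cache \
    --use-cuda \
    --batch-size 32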