This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Roll backbone #5229

Merged
merged 18 commits on May 28, 2021
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
an actual `torch.nn.Module`. Other parameters to this method have changed as well.
- Print the first batch to the console by default.
- Renamed `sanity_checks` to `confidence_checks` (`sanity_checks` is deprecated and will be removed in AllenNLP 3.0).
- VilBERT backbone now rolls and unrolls extra dimensions to handle inputs with more than 3 dimensions (see the sketch below).

### Added

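The changelog entry above describes the roll/unroll pattern this PR introduces: collapse any extra leading dimensions into a single batch dimension before the encoder runs, then restore them on the outputs. Here is a minimal, self-contained sketch of that pattern in plain PyTorch (the tensor shapes and the mean-pooling stand-in for the encoder are invented for illustration; this is not the backbone's actual API):

```python
import torch

# Hypothetical 4-D input: (batch, num_candidates, num_tokens, feature_size),
# i.e. one extra leading dimension beyond the usual 3-D case.
x = torch.randn(2, 5, 16, 32)

# Roll: collapse everything except the last two dimensions into one.
rolled_dimensions = x.shape[:-2]       # torch.Size([2, 5])
rolled_dimensions_product = 1
for dim in rolled_dimensions:
    rolled_dimensions_product *= dim   # 2 * 5 = 10
x_rolled = x.view(rolled_dimensions_product, x.shape[-2], x.shape[-1])  # (10, 16, 32)

# Stand-in for any module that expects 3-D (batch, seq, feature) input.
y = x_rolled.mean(dim=1)               # (10, 32)

# Unroll: restore the original leading dimensions on the output.
y = y.view(rolled_dimensions + (y.shape[-1],))  # (2, 5, 32)
print(y.shape)  # torch.Size([2, 5, 32])
```

The `forward` diff below applies these same two steps: token IDs, masks, box features, and box coordinates are rolled on the way into the encoder, and the sequence and pooled outputs are unrolled on the way out.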
60 changes: 50 additions & 10 deletions allennlp/modules/backbones/vilbert_backbone.py
@@ -111,19 +111,50 @@ def forward(
box_mask: torch.Tensor,
text: TextFieldTensors,
) -> Dict[str, torch.Tensor]:

if "token_ids" in text["tokens"]:
token_ids = text["tokens"]["token_ids"]
else:
token_ids = text["tokens"]["tokens"]

if token_ids.shape[:-1] != box_features.shape[:-2]:
raise ValueError(
"Tokens and boxes must have the same batch size and extra "
"dimensions (if applicable). Token size {0} did not match "
"box feature size {1}.".format(token_ids.shape[:-1], box_features.shape[:-2])
)

# Shape: (batch_size, num_tokens)
token_type_ids = text["tokens"].get("type_ids")
# Shape: (batch_size, num_tokens)
attention_mask = text["tokens"].get("mask")

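# Roll: collapse all leading dimensions (everything except the last two)
# into a single batch-like dimension so the encoder sees 3-D input.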
box_feature_dimensions = box_features.shape
feature_size = box_feature_dimensions[-1]
rolled_dimensions = box_feature_dimensions[:-2]
rolled_dimensions_product = 1
for dim in rolled_dimensions:
rolled_dimensions_product *= dim

token_ids = token_ids.view(rolled_dimensions_product, token_ids.shape[-1])
if token_type_ids is not None:
token_type_ids = token_type_ids.view(
rolled_dimensions_product, token_type_ids.shape[-1]
)
if attention_mask is not None:
attention_mask = attention_mask.view(
rolled_dimensions_product, attention_mask.shape[-1]
)
box_features = box_features.view(
rolled_dimensions_product, box_feature_dimensions[-2], feature_size
)
box_coordinates = box_coordinates.view(
rolled_dimensions_product,
box_coordinates.shape[-2],
box_coordinates.shape[-1],
)
box_mask = box_mask.view(rolled_dimensions_product, box_mask.shape[-1])
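# Every input tensor now has a single leading batch dimension of size
# rolled_dimensions_product.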

# Shape: (rolled_dimensions_product, num_tokens, embedding_dim)
embedding_output = self.text_embeddings(token_ids, token_type_ids)
num_tokens = embedding_output.size(1)

@@ -137,16 +168,16 @@

extended_image_attention_mask = box_mask

# Shape: (rolled_dimensions_product, feature_size, num_tokens)
# TODO (epwalsh): Why all zeros?? This doesn't seem right.
extended_co_attention_mask = torch.zeros(
extended_image_attention_mask.shape[0],
feature_size,
num_tokens,
dtype=extended_image_attention_mask.dtype,
)

# Shape: (rolled_dimensions_product, num_boxes, image_embedding_dim)
v_embedding_output = self.image_embeddings(box_features, box_coordinates)

encoded_layers_t, encoded_layers_v = self.encoder(
@@ -157,16 +188,16 @@
extended_co_attention_mask,
)

# Shape: (rolled_dimensions_product, num_tokens, embedding_dim)
sequence_output_t = encoded_layers_t[:, :, :, -1]
# Shape: (rolled_dimensions_product, num_boxes, image_embedding_dim)
sequence_output_v = encoded_layers_v[:, :, :, -1]

# Shape: (rolled_dimensions_product, pooled_output_dim)
pooled_output_t = self.t_pooler(sequence_output_t)
# Shape: (rolled_dimensions_product, pooled_output_dim)
pooled_output_v = self.v_pooler(sequence_output_v)

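# Unroll: restore the original leading dimensions on every output.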
sequence_output_t = sequence_output_t.view(
rolled_dimensions + (sequence_output_t.shape[-2], sequence_output_t.shape[-1])
)
sequence_output_v = sequence_output_v.view(
rolled_dimensions + (sequence_output_v.shape[-2], sequence_output_v.shape[-1])
)
pooled_output_t = pooled_output_t.view(rolled_dimensions + (pooled_output_t.shape[-1],))
pooled_output_v = pooled_output_v.view(rolled_dimensions + (pooled_output_v.shape[-1],))

if self.fusion_method == "sum":
pooled_output = self.dropout(pooled_output_t + pooled_output_v)
elif self.fusion_method == "mul":