diff --git a/CHANGELOG.md b/CHANGELOG.md
index e2058b2550b..deab8ed058a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
   an actual `torch.nn.Module`. Other parameters to this method have changed as well.
 - Print the first batch to the console by default.
 - Renamed `sanity_checks` to `confidence_checks` (`sanity_checks` is deprecated and will be removed in AllenNLP 3.0).
+- VilBERT backbone now rolls and unrolls extra dimensions to handle input with > 3 dimensions.
 
 ### Added
 
diff --git a/allennlp/modules/backbones/vilbert_backbone.py b/allennlp/modules/backbones/vilbert_backbone.py
index 0f554a7a1d2..3eeb9aad4ac 100644
--- a/allennlp/modules/backbones/vilbert_backbone.py
+++ b/allennlp/modules/backbones/vilbert_backbone.py
@@ -111,19 +111,50 @@ def forward(
         box_mask: torch.Tensor,
         text: TextFieldTensors,
     ) -> Dict[str, torch.Tensor]:
-        batch_size, _, feature_size = box_features.size()
-
         if "token_ids" in text["tokens"]:
             token_ids = text["tokens"]["token_ids"]
         else:
             token_ids = text["tokens"]["tokens"]
 
+        if token_ids.shape[:-1] != box_features.shape[:-2]:
+            raise ValueError(
+                "Tokens and boxes must have the same batch size and extra "
+                "dimensions (if applicable). Token size {0} did not match "
+                "box feature size {1}.".format(token_ids.shape[:-1], box_features.shape[:-2])
+            )
+
         # Shape: (batch_size, num_tokens)
         token_type_ids = text["tokens"].get("type_ids")
         # Shape: (batch_size, num_tokens)
         attention_mask = text["tokens"].get("mask")
 
-        # Shape: (batch_size, num_tokens, embedding_dim)
+        box_feature_dimensions = box_features.shape
+        feature_size = box_feature_dimensions[-1]
+        rolled_dimensions = box_feature_dimensions[:-2]
+        rolled_dimensions_product = 1
+        for dim in rolled_dimensions:
+            rolled_dimensions_product *= dim
+
+        token_ids = token_ids.view(rolled_dimensions_product, token_ids.shape[-1])
+        if token_type_ids is not None:
+            token_type_ids = token_type_ids.view(
+                rolled_dimensions_product, token_type_ids.shape[-1]
+            )
+        if attention_mask is not None:
+            attention_mask = attention_mask.view(
+                rolled_dimensions_product, attention_mask.shape[-1]
+            )
+        box_features = box_features.view(
+            rolled_dimensions_product, box_feature_dimensions[-2], feature_size
+        )
+        box_coordinates = box_coordinates.view(
+            rolled_dimensions_product,
+            box_coordinates.shape[-2],
+            box_coordinates.shape[-1],
+        )
+        box_mask = box_mask.view(rolled_dimensions_product, box_mask.shape[-1])
+
+        # Shape: (rolled_dimensions_product, num_tokens, embedding_dim)
         embedding_output = self.text_embeddings(token_ids, token_type_ids)
         num_tokens = embedding_output.size(1)
 
@@ -137,16 +168,16 @@ def forward(
 
         extended_image_attention_mask = box_mask
 
-        # Shape: (batch_size, feature_size, num_tokens)
+        # Shape: (rolled_dimensions_product, feature_size, num_tokens)
         # TODO (epwalsh): Why all zeros?? This doesn't seem right.
         extended_co_attention_mask = torch.zeros(
-            batch_size,
+            extended_image_attention_mask.shape[0],
             feature_size,
             num_tokens,
             dtype=extended_image_attention_mask.dtype,
         )
 
-        # Shape: (batch_size, num_boxes, image_embedding_dim)
+        # Shape: (rolled_dimensions_product, num_boxes, image_embedding_dim)
         v_embedding_output = self.image_embeddings(box_features, box_coordinates)
 
         encoded_layers_t, encoded_layers_v = self.encoder(
@@ -157,16 +188,25 @@ def forward(
             extended_co_attention_mask,
         )
 
-        # Shape: (batch_size, num_tokens, embedding_dim)
+        # Shape: (rolled_dimensions_product, num_tokens, embedding_dim)
         sequence_output_t = encoded_layers_t[:, :, :, -1]
-        # Shape: (batch_size, num_boxes, image_embedding_dim)
+        # Shape: (rolled_dimensions_product, num_boxes, image_embedding_dim)
         sequence_output_v = encoded_layers_v[:, :, :, -1]
 
-        # Shape: (batch_size, pooled_output_dim)
+        # Shape: (rolled_dimensions_product, pooled_output_dim)
         pooled_output_t = self.t_pooler(sequence_output_t)
-        # Shape: (batch_size, pooled_output_dim)
+        # Shape: (rolled_dimensions_product, pooled_output_dim)
         pooled_output_v = self.v_pooler(sequence_output_v)
 
+        sequence_output_t = sequence_output_t.view(
+            rolled_dimensions + (sequence_output_t.shape[-2], sequence_output_t.shape[-1])
+        )
+        sequence_output_v = sequence_output_v.view(
+            rolled_dimensions + (sequence_output_v.shape[-2], sequence_output_v.shape[-1])
+        )
+        pooled_output_t = pooled_output_t.view(rolled_dimensions + (pooled_output_t.shape[-1],))
+        pooled_output_v = pooled_output_v.view(rolled_dimensions + (pooled_output_v.shape[-1],))
+
         if self.fusion_method == "sum":
             pooled_output = self.dropout(pooled_output_t + pooled_output_v)
         elif self.fusion_method == "mul":
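For reviewers, here is a minimal standalone sketch of the roll/unroll pattern this diff applies. Everything in it is a hypothetical illustration rather than code from the PR: the `pool_with_extra_dims` helper, the toy pooler, and the tensor sizes are all made up. Leading "extra" dimensions are rolled into a single batch dimension with `view`, a module that only understands 3-D batches runs unchanged, and the output is unrolled back to the original leading dimensions:

```python
import torch

def pool_with_extra_dims(features: torch.Tensor, pooler: torch.nn.Module) -> torch.Tensor:
    # features: (d1, ..., dk, num_boxes, feature_size) -- any number of extra leading dims.
    rolled_dimensions = features.shape[:-2]
    rolled_dimensions_product = 1
    for dim in rolled_dimensions:
        rolled_dimensions_product *= dim

    # Roll: collapse all leading dims into one batch dimension.
    rolled = features.view(rolled_dimensions_product, features.shape[-2], features.shape[-1])

    # The wrapped module only ever sees an ordinary 3-D batch.
    pooled = pooler(rolled)  # Shape: (rolled_dimensions_product, pooled_dim)

    # Unroll: restore the original leading dimensions on the output.
    return pooled.view(rolled_dimensions + (pooled.shape[-1],))

# Hypothetical sizes: 2 instances with 3 images each, 36 boxes per image,
# 1024-dim box features pooled down to 768.
pooler = torch.nn.Sequential(torch.nn.Flatten(start_dim=1), torch.nn.Linear(36 * 1024, 768))
features = torch.randn(2, 3, 36, 1024)
print(pool_with_extra_dims(features, pooler).shape)  # torch.Size([2, 3, 768])
```

As in the diff, the sketch uses `view` rather than `reshape`, so it assumes the incoming tensors are contiguous. It also shows why the new `ValueError` check matters: tokens and boxes must agree on their extra dimensions before both are flattened by the same `rolled_dimensions_product`.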