diff --git a/mmf/configs/models/mmbt/defaults.yaml b/mmf/configs/models/mmbt/defaults.yaml
index 276318cdf9..4eb43b2721 100644
--- a/mmf/configs/models/mmbt/defaults.yaml
+++ b/mmf/configs/models/mmbt/defaults.yaml
@@ -8,6 +8,7 @@ model_config:
     freeze_modal: false
     freeze_complete_base: false
     finetune_lr_multiplier: 1
+    fused_feature_only: false
     # Dimension of the embedding finally returned by the modal encoder
     modal_hidden_size: 2048
     # Dimension of the embedding finally returned by the text encoder
diff --git a/mmf/models/base_model.py b/mmf/models/base_model.py
index 265f8378fc..80c277b5d8 100644
--- a/mmf/models/base_model.py
+++ b/mmf/models/base_model.py
@@ -43,12 +43,15 @@ def forward(self, sample_list):
 import collections
 import warnings
 from copy import deepcopy
+from dataclasses import dataclass
+from typing import Union
 
 from mmf.common.registry import registry
 from mmf.common.sample import to_device
 from mmf.modules.losses import Losses
 from mmf.utils.checkpoint import load_pretrained_model
 from mmf.utils.download import download_pretrained_model
+from omegaconf import MISSING, DictConfig, OmegaConf
 from torch import nn
 
@@ -63,8 +66,16 @@ class BaseModel(nn.Module):
     """
 
-    def __init__(self, config):
+    @dataclass
+    class Config:
+        # Name of the model that is used in registry
+        model: str = MISSING
+
+    def __init__(self, config: Union[DictConfig, Config]):
         super().__init__()
+        if not isinstance(config, DictConfig) and isinstance(config, self.Config):
+            config = OmegaConf.structured(config)
+
         self.config = config
         self._logged_warning = {"losses_present": False}
         self._is_pretrained = False
diff --git a/mmf/models/m4c.py b/mmf/models/m4c.py
index c58f6382f2..039c4be9b6 100644
--- a/mmf/models/m4c.py
+++ b/mmf/models/m4c.py
@@ -122,7 +122,7 @@ def _build_ocr_encoding(self):
         # OCR appearance feature: Faster R-CNN
         self.ocr_faster_rcnn_fc7 = build_image_encoder(
-            self._build_encoder_config, direct_features=True
+            self._build_encoder_config(), direct_features=True
         )
         self.finetune_modules.append(
             {"module": self.ocr_faster_rcnn_fc7, "lr_scale": self.config.lr_scale_frcn}
         )
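
For context on the `BaseModel.Config` change above: a model config may now be passed either as an OmegaConf `DictConfig` or as a typed dataclass instance, and `__init__` normalizes the latter with `OmegaConf.structured`. A minimal sketch of the two equivalent call paths; the `MyModelConfig` subclass here is hypothetical and not part of this patch:

    # Hypothetical config subclass, for illustration only.
    from dataclasses import dataclass

    from omegaconf import OmegaConf

    from mmf.models.base_model import BaseModel


    @dataclass
    class MyModelConfig(BaseModel.Config):
        model: str = "my_model"


    typed = MyModelConfig()  # converted inside BaseModel.__init__
    untyped = OmegaConf.structured(MyModelConfig())  # pre-converted by the caller
    # Both end up as the same DictConfig once the model is constructed.
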
diff --git a/mmf/models/mmbt.py b/mmf/models/mmbt.py
index dc4c3d05a8..2fbd6431e5 100644
--- a/mmf/models/mmbt.py
+++ b/mmf/models/mmbt.py
@@ -7,24 +7,34 @@
 import os
 from copy import deepcopy
-from typing import Dict, Optional
+from dataclasses import dataclass
+from typing import Dict, Optional, Union
 
 import torch
 from mmf.common.registry import registry
 from mmf.models.base_model import BaseModel
 from mmf.models.interfaces.mmbt import MMBTGridHMInterface
-from mmf.modules.encoders import MultiModalEncoderBase
+from mmf.modules.encoders import (
+    Encoder,
+    ImageEncoder,
+    ImageEncoderTypes,
+    MultiModalEncoderBase,
+    ResNet152ImageEncoder,
+    TextEncoder,
+    TextEncoderTypes,
+    TransformerEncoder,
+)
 from mmf.modules.hf_layers import replace_with_jit
 from mmf.utils.checkpoint import load_pretrained_model
 from mmf.utils.configuration import get_mmf_cache_dir
 from mmf.utils.modeling import get_optimizer_parameters_for_bert
-from omegaconf import OmegaConf
+from omegaconf import II, DictConfig, OmegaConf
 from torch import Tensor, nn
 from transformers.modeling_bert import BertForPreTraining, BertPredictionHeadTransform
 
 
 # TODO: Remove after transformers package upgrade to 2.5
-class MMBTConfig:
+class MMBTConfigForTransformers:
     """Configuration class to store the configuration of a `MMBT Model`.
 
     Args:
         config (:obj:`~transformers.PreTrainedConfig`):
@@ -314,7 +324,7 @@ def build(self):
         text_encoder, modal_encoder = encoders[0], encoders[1]
 
         self._encoder_config = text_encoder.config
-        self._mmbt_config = MMBTConfig(
+        self._mmbt_config = MMBTConfigForTransformers(
             self._encoder_config,
             num_labels=self.config.num_labels,
             modal_hidden_size=self.config.modal_hidden_size,
@@ -328,7 +338,10 @@ def forward(self, sample_list: Dict[str, Tensor]):
         if self._is_direct_features_input:
-            input_modal = sample_list["image_feature_0"]
+            if "input_modal" in sample_list:
+                input_modal = sample_list["input_modal"]
+            else:
+                input_modal = sample_list["image_feature_0"]
         else:
             input_modal = sample_list["image"]
 
@@ -352,14 +365,15 @@ def forward(self, sample_list: Dict[str, Tensor]):
         # If max_id is greater than 0, that means text is at 0 segment
         # which means modal will be at 1
         # In other case, it will be zero, which it already is
-        if max_id == 0:
+        # NOTE: We compare with a tensor here for TorchScript compliance
+        if max_id == torch.tensor(0, dtype=max_id.dtype):
             token_value = 1
         else:
             max_segment = self.num_max_segment - 1
             # If max id is not equal to max_segment, it means
             # text segments start from 0 which means modal will
             # be last, otherwise, it is 0, which it already is
-            if max_id != max_segment:
+            if max_id != torch.tensor(max_segment, dtype=max_id.dtype):
                 token_value = max_segment
         modal_token_type_ids = torch.full(
             (input_modal.size(0), 1),
@@ -368,6 +382,10 @@
             device=input_modal.device,
         )
 
+        # In case of XRAY, there might be only two dims
+        if input_modal.dim() == 2:
+            input_modal = input_modal.unsqueeze(dim=1)
+
         # See details of inputs at
         # https://github.com/huggingface/transformers/blob/1789c7/src/transformers/modeling_mmbt.py#L101 # noqa
         output = self.mmbt(
@@ -474,6 +492,7 @@ def __init__(self, config, *args, **kwargs):
         self.num_labels = self.config.num_labels
         self.output_hidden_states = self.encoder_config.output_hidden_states
         self.output_attentions = self.encoder_config.output_attentions
+        self.fused_feature_only = self.config.fused_feature_only
 
         self.dropout = nn.Dropout(self.encoder_config.hidden_dropout_prob)
         self.classifier = nn.Sequential(
@@ -495,6 +514,11 @@ def forward(self, sample_list: Dict[str, Tensor]):
         ), "output_attentions or output_hidden_states not supported in script mode"
         pooled_output = self.dropout(pooled_output)
+
+        if self.fused_feature_only:
+            output["fused_feature"] = self.classifier[0](pooled_output)
+            return output
+
         logits = self.classifier(pooled_output)
         reshaped_logits = logits.contiguous().view(-1, self.num_labels)
         output["scores"] = reshaped_logits
@@ -504,9 +528,41 @@
 @registry.register_model("mmbt")
 class MMBT(BaseModel):
-    def __init__(self, config):
+    @dataclass
+    class Config(BaseModel.Config):
+        model: str = "mmbt"
+        # classification or pretraining
+        training_head_type: str = "pretraining"
+        bert_model_name: str = "bert-base-uncased"
+        direct_features_input: bool = False
+        freeze_text: bool = False
+        freeze_modal: bool = False
+        freeze_complete_base: bool = False
+        finetune_lr_multiplier: float = 1
+        # Dimension of the embedding finally returned by the modal encoder
+        modal_hidden_size: int = 2048
+        text_hidden_size: int = 768
+        num_labels: int = 2
+        # This actually is Union[ImageEncoderConfig, ImageFeatureEncoderConfig]
+        modal_encoder: Encoder.Config = ImageEncoder.Config(
+            type=ImageEncoderTypes.resnet152,
+            params=ResNet152ImageEncoder.Config()
+        )
+        text_encoder: Encoder.Config = TextEncoder.Config(
+            type=TextEncoderTypes.transformer,
+            params=TransformerEncoder.Config(bert_model_name=II("bert_model_name")),
+        )
+        use_modal_start_token: bool = True
+        use_modal_end_token: bool = True
+        fused_feature_only: bool = False
+        output_dim: int = 768
+
+    def __init__(self, config: Union[DictConfig, Config], *args, **kwargs):
         super().__init__(config)
 
+    @classmethod
+    def from_params(cls, **kwargs):
+        return MMBT(OmegaConf.structured(cls.Config(**kwargs)))
+
     def build(self):
         if self.config.training_head_type == "pretraining":
             self.model = MMBTForPreTraining(self.config)
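
The `from_params` classmethod and the `fused_feature_only` flag introduced above combine as in the sketch below; this assumes network access to download `bert-base-uncased` weights, and the keyword values shown are illustrative overrides of `MMBT.Config` defaults:

    from mmf.models.mmbt import MMBT

    # Keyword overrides are routed through OmegaConf.structured(MMBT.Config(...)).
    model = MMBT.from_params(
        training_head_type="classification",
        num_labels=2,
        # With fused_feature_only=True, MMBTForClassification.forward returns
        # the projected "fused_feature" instead of classification "scores".
        fused_feature_only=True,
    )
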
diff --git a/mmf/modules/encoders.py b/mmf/modules/encoders.py
index f4208327cc..0c2c9ec7c6 100644
--- a/mmf/modules/encoders.py
+++ b/mmf/modules/encoders.py
@@ -2,6 +2,8 @@
 import os
 import pickle
 from copy import deepcopy
+from dataclasses import dataclass
+from enum import Enum
 
 import torch
 import torchvision
@@ -12,16 +14,49 @@
 from mmf.utils.download import download_pretrained_model
 from mmf.utils.file_io import PathManager
 from mmf.utils.general import get_absolute_path
-from omegaconf import OmegaConf
+from omegaconf import MISSING, OmegaConf
 from torch import nn
 from transformers.configuration_auto import AutoConfig
 from transformers.modeling_auto import AutoModel
 
 
-class ImageFeatureEncoder(nn.Module):
-    def __init__(self, config, *args, **kwargs):
+# Superclass for encoder params
+@dataclass
+class EncoderParams:
+    pass
+
+
+class Encoder(nn.Module):
+    @dataclass
+    class Config:
+        type: str = MISSING
+        params: EncoderParams = EncoderParams()
+
+
+class ImageFeatureEncoderTypes(Enum):
+    default = "default"
+    identity = "identity"
+    projection = "projection"
+    frcnn_fc7 = "finetune_faster_rcnn_fpn_fc7"
+
+
+@dataclass
+class ImageFeatureEncoderBaseParams(EncoderParams):
+    in_dim: int = MISSING
+
+
+class ImageFeatureEncoder(Encoder):
+    @dataclass
+    class Config(Encoder.Config):
+        type: ImageFeatureEncoderTypes = MISSING
+        params: EncoderParams = ImageFeatureEncoderBaseParams()
+
+    def __init__(self, config: Config, *args, **kwargs):
         super().__init__()
         encoder_type = config.type
+        if isinstance(encoder_type, ImageFeatureEncoderTypes):
+            encoder_type = encoder_type.value
+
         assert (
             "in_dim" in config.params
         ), "ImageFeatureEncoder require 'in_dim' param in config"
@@ -81,13 +116,28 @@ def forward(self, image):
         return i3
 
 
-class ImageEncoder(nn.Module):
-    def __init__(self, config, *args, **kwargs):
+class ImageEncoderTypes(Enum):
+    default = "default"
+    identity = "identity"
+    resnet152 = "resnet152"
+
+
+class ImageEncoder(Encoder):
+    @dataclass
+    class Config(Encoder.Config):
+        type: ImageEncoderTypes = MISSING
+        params: EncoderParams = EncoderParams()
+
+    def __init__(self, config: Config, *args, **kwargs):
         super().__init__()
         self._type = config.type
+
+        if isinstance(self._type, ImageEncoderTypes):
+            self._type = self._type.value
+
         params = config.params
 
-        if self._type == "default":
+        if self._type == "default" or self._type == "identity":
             self.module = nn.Identity()
             self.module.out_dim = params.in_dim
         elif self._type == "resnet152":
@@ -105,7 +155,14 @@ def forward(self, image):
 
 # Taken from facebookresearch/mmbt with some modifications
 class ResNet152ImageEncoder(nn.Module):
-    def __init__(self, config, *args, **kwargs):
+    @dataclass
+    class Config(EncoderParams):
+        pretrained: bool = True
+        # "avg" or "adaptive"
+        pool_type: str = "avg"
+        num_output_features: int = 1
+
+    def __init__(self, config: Config, *args, **kwargs):
         super().__init__()
         self.config = config
         model = torchvision.models.resnet152(pretrained=config.get("pretrained", True))
@@ -140,10 +197,24 @@ def forward(self, x):
         return out  # BxNx2048
 
 
+class TextEncoderTypes(Enum):
+    identity = "identity"
+    transformer = "transformer"
+    embedding = "embedding"
+
+
 class TextEncoder(nn.Module):
-    def __init__(self, config, *args, **kwargs):
+    @dataclass
+    class Config(Encoder.Config):
+        # identity, transformer or embedding as of now
+        type: TextEncoderTypes = MISSING
+        params: EncoderParams = EncoderParams()
+
+    def __init__(self, config: Config, *args, **kwargs):
         super().__init__()
         self._type = config.type
+        if isinstance(self._type, TextEncoderTypes):
+            self._type = self._type.value
 
         if self._type == "identity":
             self.module = nn.Identity()
@@ -182,7 +253,26 @@ def forward(self, x):
 
 
 class TransformerEncoder(nn.Module):
-    def __init__(self, config, *args, **kwargs):
+    @dataclass
+    class Config(EncoderParams):
+        num_segments: int = 2
+        bert_model_name: str = "bert-base-uncased"
+        # The options below can be overridden to update the BERT configuration
+        # used to initialize the BERT encoder. If an option is missing, or if
+        # you are using an encoder other than BERT, add extra parameters by
+        # inheriting and extending this config. Those options automatically
+        # override the options in your transformer encoder's configuration.
+        # For example, vocab_size is missing here; just add `vocab_size: x` to
+        # set the size of the vocabulary with which the encoder is initialized.
+        # If you change the default values, the transformer you get will be
+        # initialized from scratch.
+        hidden_size: int = 768
+        num_hidden_layers: int = 12
+        num_attention_heads: int = 12
+        output_attentions: bool = False
+        output_hidden_states: bool = False
+
+    def __init__(self, config: Config, *args, **kwargs):
         super().__init__()
         self.config = config
         self.module = AutoModel.from_pretrained(
@@ -209,7 +299,7 @@ def _init_segment_embeddings(self):
         )
         self.embeddings.token_type_embeddings = new_embeds
 
-    def _build_encoder_config(self, config):
+    def _build_encoder_config(self, config: Config):
         return AutoConfig.from_pretrained(
             self.config.bert_model_name, **OmegaConf.to_container(self.config)
         )
@@ -222,7 +312,20 @@ def forward(self, *args, **kwargs):
 class MultiModalEncoderBase(nn.Module):
     __jit_unused_properties__ = ["encoder_config"]
 
-    def __init__(self, config, *args, **kwargs):
+    @dataclass
+    class Config(EncoderParams):
+        # This actually is Union[ImageEncoderConfig, ImageFeatureEncoderConfig]
+        modal_encoder: Encoder.Config = ImageEncoder.Config(
+            type=ImageEncoderTypes.resnet152, params=ResNet152ImageEncoder.Config()
+        )
+        text_encoder: Encoder.Config = TextEncoder.Config(
+            type=TextEncoderTypes.transformer, params=TransformerEncoder.Config()
+        )
+        direct_features_input: bool = False
+        modal_hidden_size: int = 2048
+        text_hidden_size: int = 768
+
+    def __init__(self, config: Config, *args, **kwargs):
         super().__init__()
         self.config = config
 
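Because the encoder `type` fields are now enums whose values match the strings used in YAML configs, encoder configs can be composed programmatically; a small sketch, using `pretrained=False` to avoid a weight download:

    from omegaconf import OmegaConf

    from mmf.modules.encoders import (
        ImageEncoder,
        ImageEncoderTypes,
        ResNet152ImageEncoder,
    )

    # Enum members carry their YAML string values; the isinstance checks added
    # in the encoder constructors normalize them back to plain strings.
    config = OmegaConf.structured(
        ImageEncoder.Config(
            type=ImageEncoderTypes.resnet152,
            params=ResNet152ImageEncoder.Config(pretrained=False),
        )
    )
    encoder = ImageEncoder(config)
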
diff --git a/mmf/utils/build.py b/mmf/utils/build.py
index 719692798a..899ad02fd3 100644
--- a/mmf/utils/build.py
+++ b/mmf/utils/build.py
@@ -2,8 +2,9 @@
 import os
 import warnings
-from typing import Any, Dict, Type
+from typing import Any, Dict, Type, Union
 
+import mmf
 import torch
 from mmf.common import typings as mmf_typings
 from mmf.common.registry import registry
@@ -11,7 +12,7 @@
 from mmf.utils.configuration import Configuration
 from mmf.utils.distributed import is_dist_initialized
 from mmf.utils.general import get_optimizer_parameters
-from omegaconf import OmegaConf
+from omegaconf import DictConfig, OmegaConf
 
 ProcessorType = Type[Processor]
 
@@ -57,9 +58,16 @@ def build_trainer(config: mmf_typings.DictConfig) -> Any:
     return trainer_obj
 
 
-def build_model(config):
-    model_name = config.model
+def build_model(
+    config: Union[DictConfig, "mmf.models.base_model.BaseModel.Config"]
+) -> "mmf.models.base_model.BaseModel":
+    from mmf.models.base_model import BaseModel
+
+    # If it is not an OmegaConf object, create the object
+    if not isinstance(config, DictConfig) and isinstance(config, BaseModel.Config):
+        config = OmegaConf.structured(config)
+
+    model_name = config.model
     model_class = registry.get_model_class(model_name)
 
     if model_class is None:
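
`build_model` now accepts both forms as well; a sketch under the assumption that `setup_imports` has been called so that "mmbt" is present in the registry (building the model will fetch pretrained BERT weights):

    from mmf.models.mmbt import MMBT
    from mmf.utils.build import build_model
    from mmf.utils.env import setup_imports

    setup_imports()  # populates the model registry

    # A typed dataclass config is converted internally via OmegaConf.structured.
    model = build_model(MMBT.Config(training_head_type="classification"))
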
diff --git a/tests/models/test_mmbt.py b/tests/models/test_mmbt.py
index ead80a47a9..d1209a3d72 100644
--- a/tests/models/test_mmbt.py
+++ b/tests/models/test_mmbt.py
@@ -3,13 +3,21 @@
 import io
 import unittest
 
+import tests.test_utils as test_utils
 import torch
 from mmf.common.registry import registry
 from mmf.common.sample import Sample, SampleList
+from mmf.models.mmbt import MMBT
+from mmf.modules.encoders import (
+    ImageEncoder,
+    ImageEncoderTypes,
+    ResNet152ImageEncoder,
+    TextEncoder,
+    TextEncoderTypes,
+)
 from mmf.utils.configuration import Configuration
 from mmf.utils.env import setup_imports
-
-import tests.test_utils as test_utils
+from omegaconf import OmegaConf
 
 
 class TestMMBTTorchscript(unittest.TestCase):
@@ -53,3 +61,48 @@ def test_finetune_model(self):
         script_output = script_model(test_sample_list)
 
         self.assertTrue(torch.equal(model_output["scores"], script_output["scores"]))
+
+
+class TestMMBTConfig(unittest.TestCase):
+    def test_mmbt_from_params(self):
+        # default init
+        mmbt = MMBT.from_params(
+            modal_encoder=ImageEncoder.Config(
+                type=ImageEncoderTypes.resnet152,
+                params=ResNet152ImageEncoder.Config(pretrained=False),
+            ),
+            text_encoder=TextEncoder.Config(type=TextEncoderTypes.identity),
+        )
+
+        config = OmegaConf.structured(
+            MMBT.Config(
+                modal_encoder=ImageEncoder.Config(
+                    type=ImageEncoderTypes.resnet152,
+                    params=ResNet152ImageEncoder.Config(pretrained=False),
+                ),
+                text_encoder=TextEncoder.Config(type=TextEncoderTypes.identity),
+            )
+        )
+        self.assertIsNotNone(mmbt)
+        # Make sure that the config is created from MMBT.Config
+        self.assertEqual(mmbt.config, config)
+
+    @test_utils.skip_if_no_network
+    def test_mmbt_pretrained(self):
+        mmbt = MMBT.from_params()
+        self.assertIsNotNone(mmbt)
+
+    def test_mmbt_directly_from_config(self):
+        config = OmegaConf.structured(
+            MMBT.Config(
+                modal_encoder=ImageEncoder.Config(
+                    type=ImageEncoderTypes.resnet152,
+                    params=ResNet152ImageEncoder.Config(pretrained=False),
+                ),
+                text_encoder=TextEncoder.Config(type=TextEncoderTypes.identity),
+            )
+        )
+        mmbt = MMBT(config)
+        self.assertIsNotNone(mmbt)
+        # Make sure that the config is created from MMBT.Config
+        self.assertEqual(mmbt.config, config)