From 6b271de5d3cd43b4e33679da18667f5310a91780 Mon Sep 17 00:00:00 2001
From: Amanpreet Singh
Date: Wed, 7 Oct 2020 12:04:59 -0700
Subject: [PATCH] [feat] Add init dataclasses for mmbt and encoders (#565)

Summary:
Pull Request resolved: https://github.com/facebookresearch/mmf/pull/565

As step one of FAIM integration, we allow building our models from config
so that models are fully decoupled from the configuration system and users
know exactly what each model expects. We do this first for the MMBT model.

- Adds configs for MMBT and the respective encoders.
- Adds a from_params method to MMBT.
- Updates the build method to support passing a config object directly.
- Adds a Config class to BaseModel as well and updates typings.

There is an issue with OmegaConf that prevents us from using Union in
structured configs; see https://github.com/omry/omegaconf/issues/144

(A usage sketch of the new APIs follows the diff.)

Differential Revision: D23699688

fbshipit-source-id: 3f0f43938814a1c6bfdef51c048bc7ee8ef2c28b
---
 mmf/configs/models/mmbt/defaults.yaml |   1 +
 mmf/models/base_model.py              |  13 ++-
 mmf/models/m4c.py                     |   2 +-
 mmf/models/mmbt.py                    |  74 ++++++++++++++--
 mmf/modules/encoders.py               | 121 +++++++++++++++++++++++---
 mmf/utils/build.py                    |  16 +++-
 tests/models/test_mmbt.py             |  57 +++++++++++-
 7 files changed, 256 insertions(+), 28 deletions(-)

diff --git a/mmf/configs/models/mmbt/defaults.yaml b/mmf/configs/models/mmbt/defaults.yaml
index 276318cdf9..4eb43b2721 100644
--- a/mmf/configs/models/mmbt/defaults.yaml
+++ b/mmf/configs/models/mmbt/defaults.yaml
@@ -8,6 +8,7 @@ model_config:
     freeze_modal: false
     freeze_complete_base: false
     finetune_lr_multiplier: 1
+    fused_feature_only: false
     # Dimension of the embedding finally returned by the modal encoder
     modal_hidden_size: 2048
     # Dimension of the embedding finally returned by the text encoder
diff --git a/mmf/models/base_model.py b/mmf/models/base_model.py
index 265f8378fc..80c277b5d8 100644
--- a/mmf/models/base_model.py
+++ b/mmf/models/base_model.py
@@ -43,12 +43,15 @@ def forward(self, sample_list):
 
 import collections
 import warnings
 from copy import deepcopy
+from dataclasses import dataclass
+from typing import Union
 
 from mmf.common.registry import registry
 from mmf.common.sample import to_device
 from mmf.modules.losses import Losses
 from mmf.utils.checkpoint import load_pretrained_model
 from mmf.utils.download import download_pretrained_model
+from omegaconf import MISSING, DictConfig, OmegaConf
 from torch import nn
 
 
@@ -63,8 +66,16 @@ class BaseModel(nn.Module):
 
     """
 
-    def __init__(self, config):
+    @dataclass
+    class Config:
+        # Name of the model that is used in registry
+        model: str = MISSING
+
+    def __init__(self, config: Union[DictConfig, Config]):
         super().__init__()
+        if not isinstance(config, DictConfig) and isinstance(config, self.Config):
+            config = OmegaConf.structured(config)
+
         self.config = config
         self._logged_warning = {"losses_present": False}
         self._is_pretrained = False
diff --git a/mmf/models/m4c.py b/mmf/models/m4c.py
index c58f6382f2..039c4be9b6 100644
--- a/mmf/models/m4c.py
+++ b/mmf/models/m4c.py
@@ -122,7 +122,7 @@ def _build_ocr_encoding(self):
 
         # OCR appearance feature: Faster R-CNN
         self.ocr_faster_rcnn_fc7 = build_image_encoder(
-            self._build_encoder_config, direct_features=True
+            self._build_encoder_config(), direct_features=True
         )
         self.finetune_modules.append(
             {"module": self.ocr_faster_rcnn_fc7, "lr_scale": self.config.lr_scale_frcn}
diff --git a/mmf/models/mmbt.py b/mmf/models/mmbt.py
index dc4c3d05a8..2fbd6431e5 100644
--- a/mmf/models/mmbt.py
+++ b/mmf/models/mmbt.py
@@ -7,24 +7,34 @@
 import os
from copy import deepcopy -from typing import Dict, Optional +from dataclasses import dataclass +from typing import Dict, Optional, Union import torch from mmf.common.registry import registry from mmf.models.base_model import BaseModel from mmf.models.interfaces.mmbt import MMBTGridHMInterface -from mmf.modules.encoders import MultiModalEncoderBase +from mmf.modules.encoders import ( + Encoder, + ImageEncoder, + ImageEncoderTypes, + MultiModalEncoderBase, + ResNet152ImageEncoder, + TextEncoder, + TextEncoderTypes, + TransformerEncoder, +) from mmf.modules.hf_layers import replace_with_jit from mmf.utils.checkpoint import load_pretrained_model from mmf.utils.configuration import get_mmf_cache_dir from mmf.utils.modeling import get_optimizer_parameters_for_bert -from omegaconf import OmegaConf +from omegaconf import II, DictConfig, OmegaConf from torch import Tensor, nn from transformers.modeling_bert import BertForPreTraining, BertPredictionHeadTransform # TODO: Remove after transformers package upgrade to 2.5 -class MMBTConfig: +class MMBTConfigForTransformers: """Configuration class to store the configuration of a `MMBT Model`. Args: config (:obj:`~transformers.PreTrainedConfig`): @@ -314,7 +324,7 @@ def build(self): text_encoder, modal_encoder = encoders[0], encoders[1] self._encoder_config = text_encoder.config - self._mmbt_config = MMBTConfig( + self._mmbt_config = MMBTConfigForTransformers( self._encoder_config, num_labels=self.config.num_labels, modal_hidden_size=self.config.modal_hidden_size, @@ -328,7 +338,10 @@ def build(self): def forward(self, sample_list: Dict[str, Tensor]): if self._is_direct_features_input: - input_modal = sample_list["image_feature_0"] + if "input_modal" in sample_list: + input_modal = sample_list["input_modal"] + else: + input_modal = sample_list["image_feature_0"] else: input_modal = sample_list["image"] @@ -352,14 +365,15 @@ def forward(self, sample_list: Dict[str, Tensor]): # If max_id is greater than 0, that means text is at 0 segment # which means modal will be at 1 # In other case, it will be zero, which it already is - if max_id == 0: + # NOTE: We compare with tensor here due to TorchScript compliance + if max_id == torch.tensor(0, dtype=max_id.dtype): token_value = 1 else: max_segment = self.num_max_segment - 1 # If max id is not equal to max_segment, it means # text segments start from 0 which means modal will # be last, otherwise, it is 0, which it already is - if max_id != max_segment: + if max_id != torch.tensor(max_segment, dtype=max_id.dtype): token_value = max_segment modal_token_type_ids = torch.full( (input_modal.size(0), 1), @@ -368,6 +382,10 @@ def forward(self, sample_list: Dict[str, Tensor]): device=input_modal.device, ) + # In case of XRAY, there might be only two dims + if input_modal.dim() == 2: + input_modal = input_modal.unsqueeze(dim=1) + # See details of inputs at # https://github.com/huggingface/transformers/blob/1789c7/src/transformers/modeling_mmbt.py#L101 # noqa output = self.mmbt( @@ -474,6 +492,7 @@ def __init__(self, config, *args, **kwargs): self.num_labels = self.config.num_labels self.output_hidden_states = self.encoder_config.output_hidden_states self.output_attentions = self.encoder_config.output_attentions + self.fused_feature_only = self.config.fused_feature_only self.dropout = nn.Dropout(self.encoder_config.hidden_dropout_prob) self.classifier = nn.Sequential( @@ -495,6 +514,11 @@ def forward(self, sample_list: Dict[str, Tensor]): ), "output_attentions or output_hidden_states not supported in script mode" 
         pooled_output = self.dropout(pooled_output)
+
+        if self.fused_feature_only:
+            output["fused_feature"] = self.classifier[0](pooled_output)
+            return output
+
         logits = self.classifier(pooled_output)
         reshaped_logits = logits.contiguous().view(-1, self.num_labels)
         output["scores"] = reshaped_logits
@@ -504,9 +528,41 @@
 
 @registry.register_model("mmbt")
 class MMBT(BaseModel):
-    def __init__(self, config):
+    @dataclass
+    class Config(BaseModel.Config):
+        model: str = "mmbt"
+        # classification or pretraining
+        training_head_type: str = "pretraining"
+        bert_model_name: str = "bert-base-uncased"
+        direct_features_input: bool = False
+        freeze_text: bool = False
+        freeze_modal: bool = False
+        freeze_complete_base: bool = False
+        finetune_lr_multiplier: float = 1
+        # Dimension of the embedding finally returned by the modal encoder
+        modal_hidden_size: int = 2048
+        text_hidden_size: int = 768
+        num_labels: int = 2
+        # This actually is Union[ImageEncoderConfig, ImageFeatureEncoderConfig]
+        modal_encoder: Encoder.Config = ImageEncoder.Config(
+            type=ImageEncoderTypes.resnet152, params=ResNet152ImageEncoder.Config()
+        )
+        text_encoder: Encoder.Config = TextEncoder.Config(
+            type=TextEncoderTypes.transformer,
+            params=TransformerEncoder.Config(bert_model_name=II("bert_model_name")),
+        )
+        use_modal_start_token: bool = True
+        use_modal_end_token: bool = True
+        fused_feature_only: bool = False
+        output_dim: int = 768
+
+    def __init__(self, config: Union[DictConfig, Config], *args, **kwargs):
         super().__init__(config)
 
+    @classmethod
+    def from_params(cls, **kwargs):
+        return MMBT(OmegaConf.structured(cls.Config(**kwargs)))
+
     def build(self):
         if self.config.training_head_type == "pretraining":
             self.model = MMBTForPreTraining(self.config)
diff --git a/mmf/modules/encoders.py b/mmf/modules/encoders.py
index f4208327cc..0c2c9ec7c6 100644
--- a/mmf/modules/encoders.py
+++ b/mmf/modules/encoders.py
@@ -2,6 +2,8 @@
 import os
 import pickle
 from copy import deepcopy
+from dataclasses import dataclass
+from enum import Enum
 
 import torch
 import torchvision
@@ -12,16 +14,49 @@
 from mmf.utils.download import download_pretrained_model
 from mmf.utils.file_io import PathManager
 from mmf.utils.general import get_absolute_path
-from omegaconf import OmegaConf
+from omegaconf import MISSING, OmegaConf
 from torch import nn
 from transformers.configuration_auto import AutoConfig
 from transformers.modeling_auto import AutoModel
 
 
-class ImageFeatureEncoder(nn.Module):
-    def __init__(self, config, *args, **kwargs):
+# Superclass for encoder params
+@dataclass
+class EncoderParams:
+    pass
+
+
+class Encoder(nn.Module):
+    @dataclass
+    class Config:
+        type: str = MISSING
+        params: EncoderParams = EncoderParams()
+
+
+class ImageFeatureEncoderTypes(Enum):
+    default = "default"
+    identity = "identity"
+    projection = "projection"
+    frcnn_fc7 = "finetune_faster_rcnn_fpn_fc7"
+
+
+@dataclass
+class ImageFeatureEncoderBaseParams(EncoderParams):
+    in_dim: int = MISSING
+
+
+class ImageFeatureEncoder(Encoder):
+    @dataclass
+    class Config(Encoder.Config):
+        type: ImageFeatureEncoderTypes = MISSING
+        params: EncoderParams = ImageFeatureEncoderBaseParams()
+
+    def __init__(self, config: Config, *args, **kwargs):
         super().__init__()
         encoder_type = config.type
+        if isinstance(encoder_type, ImageFeatureEncoderTypes):
+            encoder_type = encoder_type.value
+
         assert (
             "in_dim" in config.params
         ), "ImageFeatureEncoder require 'in_dim' param in config"
@@ -81,13 +116,28 @@ def forward(self, image):
         return i3
 
 
-class ImageEncoder(nn.Module):
-    def __init__(self, config, *args, **kwargs):
+class ImageEncoderTypes(Enum):
+    default = "default"
+    identity = "identity"
+    resnet152 = "resnet152"
+
+
+class ImageEncoder(Encoder):
+    @dataclass
+    class Config(Encoder.Config):
+        type: ImageEncoderTypes = MISSING
+        params: EncoderParams = EncoderParams()
+
+    def __init__(self, config: Config, *args, **kwargs):
         super().__init__()
         self._type = config.type
+
+        if isinstance(self._type, ImageEncoderTypes):
+            self._type = self._type.value
+
         params = config.params
 
-        if self._type == "default":
+        if self._type == "default" or self._type == "identity":
             self.module = nn.Identity()
             self.module.out_dim = params.in_dim
         elif self._type == "resnet152":
@@ -105,7 +155,14 @@ def forward(self, image):
 
 # Taken from facebookresearch/mmbt with some modifications
 class ResNet152ImageEncoder(nn.Module):
-    def __init__(self, config, *args, **kwargs):
+    @dataclass
+    class Config(EncoderParams):
+        pretrained: bool = True
+        # "avg" or "adaptive"
+        pool_type: str = "avg"
+        num_output_features: int = 1
+
+    def __init__(self, config: Config, *args, **kwargs):
         super().__init__()
         self.config = config
         model = torchvision.models.resnet152(pretrained=config.get("pretrained", True))
@@ -140,10 +197,24 @@ def forward(self, x):
         return out  # BxNx2048
 
 
+class TextEncoderTypes(Enum):
+    identity = "identity"
+    transformer = "transformer"
+    embedding = "embedding"
+
+
 class TextEncoder(nn.Module):
-    def __init__(self, config, *args, **kwargs):
+    @dataclass
+    class Config(Encoder.Config):
+        # identity, transformer or embedding as of now
+        type: TextEncoderTypes = MISSING
+        params: EncoderParams = EncoderParams()
+
+    def __init__(self, config: Config, *args, **kwargs):
         super().__init__()
         self._type = config.type
+        if isinstance(self._type, TextEncoderTypes):
+            self._type = self._type.value
 
         if self._type == "identity":
             self.module = nn.Identity()
@@ -182,7 +253,26 @@ def forward(self, x):
 
 
 class TransformerEncoder(nn.Module):
-    def __init__(self, config, *args, **kwargs):
+    @dataclass
+    class Config(EncoderParams):
+        num_segments: int = 2
+        bert_model_name: str = "bert-base-uncased"
+        # The options below can be overridden to update the BERT configuration
+        # used to initialize the BERT encoder. If an option is missing, or if
+        # you are using an encoder other than BERT, add extra parameters by
+        # inheriting from and extending this config. Those options will
+        # automatically override the options in your transformer encoder's
+        # configuration. For example, vocab_size is missing here; just add
+        # vocab_size: x to update the size of the vocabulary with which the
+        # encoder is initialized. If you update the default values, the
+        # transformer you get will be initialized from scratch.
+        hidden_size: int = 768
+        num_hidden_layers: int = 12
+        num_attention_heads: int = 12
+        output_attentions: bool = False
+        output_hidden_states: bool = False
+
+    def __init__(self, config: Config, *args, **kwargs):
         super().__init__()
         self.config = config
         self.module = AutoModel.from_pretrained(
@@ -209,7 +299,7 @@ def _init_segment_embeddings(self):
         )
         self.embeddings.token_type_embeddings = new_embeds
 
-    def _build_encoder_config(self, config):
+    def _build_encoder_config(self, config: Config):
         return AutoConfig.from_pretrained(
             self.config.bert_model_name, **OmegaConf.to_container(self.config)
         )
@@ -222,7 +312,16 @@ def forward(self, *args, **kwargs):
 
 class MultiModalEncoderBase(nn.Module):
     __jit_unused_properties__ = ["encoder_config"]
 
-    def __init__(self, config, *args, **kwargs):
+    @dataclass
+    class Config(EncoderParams):
+        # This actually is Union[ImageEncoderConfig, ImageFeatureEncoderConfig]
+        modal_encoder: Encoder.Config = ResNet152ImageEncoder.Config()
+        text_encoder: TextEncoder.Config = TransformerEncoder.Config()
+        direct_features_input: bool = False
+        modal_hidden_size: int = 2048
+        text_hidden_size: int = 768
+
+    def __init__(self, config: Config, *args, **kwargs):
         super().__init__()
         self.config = config
diff --git a/mmf/utils/build.py b/mmf/utils/build.py
index 719692798a..899ad02fd3 100644
--- a/mmf/utils/build.py
+++ b/mmf/utils/build.py
@@ -2,8 +2,9 @@
 import os
 import warnings
-from typing import Any, Dict, Type
+from typing import Any, Dict, Type, Union
 
+import mmf
 import torch
 from mmf.common import typings as mmf_typings
 from mmf.common.registry import registry
@@ -11,7 +12,7 @@
 from mmf.utils.configuration import Configuration
 from mmf.utils.distributed import is_dist_initialized
 from mmf.utils.general import get_optimizer_parameters
-from omegaconf import OmegaConf
+from omegaconf import DictConfig, OmegaConf
 
 ProcessorType = Type[Processor]
 
@@ -57,9 +58,16 @@ def build_trainer(config: mmf_typings.DictConfig) -> Any:
     return trainer_obj
 
 
-def build_model(config):
-    model_name = config.model
+def build_model(
+    config: Union[DictConfig, "mmf.models.base_model.BaseModel.Config"]
+) -> "mmf.models.base_model.BaseModel":
+    from mmf.models.base_model import BaseModel
+
+    # If it is not an OmegaConf DictConfig, build one from the dataclass config
+    if not isinstance(config, DictConfig) and isinstance(config, BaseModel.Config):
+        config = OmegaConf.structured(config)
 
+    model_name = config.model
     model_class = registry.get_model_class(model_name)
     if model_class is None:
diff --git a/tests/models/test_mmbt.py b/tests/models/test_mmbt.py
index ead80a47a9..d1209a3d72 100644
--- a/tests/models/test_mmbt.py
+++ b/tests/models/test_mmbt.py
@@ -3,13 +3,21 @@
 import io
 import unittest
 
+import tests.test_utils as test_utils
 import torch
 from mmf.common.registry import registry
 from mmf.common.sample import Sample, SampleList
+from mmf.models.mmbt import MMBT
+from mmf.modules.encoders import (
+    ImageEncoder,
+    ImageEncoderTypes,
+    ResNet152ImageEncoder,
+    TextEncoder,
+    TextEncoderTypes,
+)
 from mmf.utils.configuration import Configuration
 from mmf.utils.env import setup_imports
-
-import tests.test_utils as test_utils
+from omegaconf import OmegaConf
 
 
 class TestMMBTTorchscript(unittest.TestCase):
@@ -53,3 +61,48 @@ def test_finetune_model(self):
         script_output = script_model(test_sample_list)
 
         self.assertTrue(torch.equal(model_output["scores"], script_output["scores"]))
+
+
+class TestMMBTConfig(unittest.TestCase):
+    def test_mmbt_from_params(self):
+        # build the model directly from params
+        mmbt = MMBT.from_params(
modal_encoder=ImageEncoder.Config( + type=ImageEncoderTypes.resnet152, + params=ResNet152ImageEncoder.Config(pretrained=False), + ), + text_encoder=TextEncoder.Config(type=TextEncoderTypes.identity), + ) + + config = OmegaConf.structured( + MMBT.Config( + modal_encoder=ImageEncoder.Config( + type=ImageEncoderTypes.resnet152, + params=ResNet152ImageEncoder.Config(pretrained=False), + ), + text_encoder=TextEncoder.Config(type=TextEncoderTypes.identity), + ) + ) + self.assertIsNotNone(mmbt) + # Make sure that the config is created from MMBT.Config + self.assertEqual(mmbt.config, config) + + @test_utils.skip_if_no_network + def test_mmbt_pretrained(self): + mmbt = MMBT.from_params() + self.assertIsNotNone(mmbt) + + def test_mmbt_directly_from_config(self): + config = OmegaConf.structured( + MMBT.Config( + modal_encoder=ImageEncoder.Config( + type=ImageEncoderTypes.resnet152, + params=ResNet152ImageEncoder.Config(pretrained=False), + ), + text_encoder=TextEncoder.Config(type=TextEncoderTypes.identity), + ) + ) + mmbt = MMBT(config) + self.assertIsNotNone(mmbt) + # Make sure that the config is created from MMBT.Config + self.assertEqual(mmbt.config, config)
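
Usage sketch (illustration only, not part of the patch): the snippet below
exercises the entry points this diff adds (MMBT.from_params, direct
construction from MMBT.Config, and build_model accepting a dataclass
config), mirroring tests/models/test_mmbt.py. Note that build_model() also
calls build(), which will download pretrained weights for the default
encoders.

    from mmf.models.mmbt import MMBT
    from mmf.modules.encoders import (
        ImageEncoder,
        ImageEncoderTypes,
        ResNet152ImageEncoder,
        TextEncoder,
        TextEncoderTypes,
    )
    from mmf.utils.build import build_model
    from omegaconf import OmegaConf

    # 1. Build from keyword params: from_params wraps the kwargs into
    #    OmegaConf.structured(MMBT.Config(**kwargs)).
    model = MMBT.from_params(
        modal_encoder=ImageEncoder.Config(
            type=ImageEncoderTypes.resnet152,
            params=ResNet152ImageEncoder.Config(pretrained=False),
        ),
        text_encoder=TextEncoder.Config(type=TextEncoderTypes.identity),
    )

    # 2. Build directly from a config object. BaseModel.__init__ now converts
    #    a plain dataclass Config into an OmegaConf DictConfig.
    model = MMBT(OmegaConf.structured(MMBT.Config()))  # explicit conversion
    model = MMBT(MMBT.Config())  # or let BaseModel convert it

    # 3. build_model also accepts the dataclass config and looks the model
    #    class up in the registry via config.model ("mmbt" here).
    model = build_model(MMBT.Config())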