[feat] Add init dataclasses for mmbt and encoders (facebookresearch#565)
Summary:
Pull Request resolved: facebookresearch#565

As step one of FAIM integration, we allow building our models from config so that the models are fully decoupled from configuration and users know exactly what each model expects.

We first do this for the MMBT model.

- Adds configs for MMBT and its respective encoders.
- Adds a from_params method for MMBT (see the usage sketch below).
- Updates the build method to support passing a config object directly.
- Adds a Config class to BaseModel and updates typings.

There is an issue with OmegaConf that does not let us use Union in structured configs; see omry/omegaconf#144.
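A minimal usage sketch of the new entry point (field values here are illustrative; from_params wraps cls.Config(**kwargs) in OmegaConf.structured, as the mmbt.py diff below shows):

from mmf.models.mmbt import MMBT
from omegaconf import OmegaConf

# Build MMBT from typed config fields; unknown keys now fail fast
# instead of silently flowing through an untyped dict.
model = MMBT.from_params(training_head_type="classification", num_labels=2)

# Equivalent explicit form using the structured Config dataclass:
config = OmegaConf.structured(MMBT.Config(training_head_type="classification"))
model = MMBT(config)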

Differential Revision: D23699688

fbshipit-source-id: 3f0f43938814a1c6bfdef51c048bc7ee8ef2c28b
apsdehal authored and facebook-github-bot committed Oct 7, 2020
1 parent c55f821 commit 6b271de
Showing 7 changed files with 256 additions and 28 deletions.
1 change: 1 addition & 0 deletions mmf/configs/models/mmbt/defaults.yaml
@@ -8,6 +8,7 @@ model_config:
freeze_modal: false
freeze_complete_base: false
finetune_lr_multiplier: 1
fused_feature_only: false
# Dimension of the embedding finally returned by the modal encoder
modal_hidden_size: 2048
# Dimension of the embedding finally returned by the text encoder
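The new fused_feature_only flag makes MMBTForClassification return the fused feature from the first classifier layer instead of classification scores (see the mmbt.py diff below). A small sketch, assuming the model_config layout from this file, of flipping the flag with an OmegaConf dotlist override:

from omegaconf import OmegaConf

base = OmegaConf.create({"model_config": {"mmbt": {"fused_feature_only": False}}})
override = OmegaConf.from_dotlist(["model_config.mmbt.fused_feature_only=true"])
merged = OmegaConf.merge(base, override)
assert merged.model_config.mmbt.fused_feature_only is True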
13 changes: 12 additions & 1 deletion mmf/models/base_model.py
@@ -43,12 +43,15 @@ def forward(self, sample_list):
import collections
import warnings
from copy import deepcopy
from dataclasses import dataclass
from typing import Union

from mmf.common.registry import registry
from mmf.common.sample import to_device
from mmf.modules.losses import Losses
from mmf.utils.checkpoint import load_pretrained_model
from mmf.utils.download import download_pretrained_model
from omegaconf import MISSING, DictConfig, OmegaConf
from torch import nn


@@ -63,8 +66,16 @@ class BaseModel(nn.Module):
"""

def __init__(self, config):
@dataclass
class Config:
# Name of the model that is used in registry
model: str = MISSING

def __init__(self, config: Union[DictConfig, Config]):
super().__init__()
if not isinstance(config, DictConfig) and isinstance(config, self.Config):
config = OmegaConf.structured(config)

self.config = config
self._logged_warning = {"losses_present": False}
self._is_pretrained = False
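A standalone sketch (a simplified stand-in, not the real BaseModel) of the new constructor contract, where a structured Config instance is normalized to a DictConfig:

from dataclasses import dataclass
from omegaconf import MISSING, DictConfig, OmegaConf

@dataclass
class Config:
    model: str = MISSING

def normalize_config(config):
    # Mirrors BaseModel.__init__: dataclass instances are wrapped into a
    # DictConfig so downstream code always sees one config type.
    if not isinstance(config, DictConfig) and isinstance(config, Config):
        config = OmegaConf.structured(config)
    return config

cfg = normalize_config(Config(model="mmbt"))
assert isinstance(cfg, DictConfig) and cfg.model == "mmbt"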
2 changes: 1 addition & 1 deletion mmf/models/m4c.py
@@ -122,7 +122,7 @@ def _build_ocr_encoding(self):

# OCR appearance feature: Faster R-CNN
self.ocr_faster_rcnn_fc7 = build_image_encoder(
self._build_encoder_config, direct_features=True
self._build_encoder_config(), direct_features=True
)
self.finetune_modules.append(
{"module": self.ocr_faster_rcnn_fc7, "lr_scale": self.config.lr_scale_frcn}
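The m4c.py change fixes a subtle bug: the bound method _build_encoder_config was passed to build_image_encoder instead of its return value. A toy illustration with hypothetical names:

class Model:
    # Hypothetical stand-in for M4C._build_encoder_config.
    def _build_encoder_config(self):
        return {"in_dim": 2048}

m = Model()
broken = m._build_encoder_config    # a bound method object, not a config
fixed = m._build_encoder_config()   # the actual config dict
assert callable(broken)
assert fixed["in_dim"] == 2048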
74 changes: 65 additions & 9 deletions mmf/models/mmbt.py
@@ -7,24 +7,34 @@

import os
from copy import deepcopy
from typing import Dict, Optional
from dataclasses import dataclass
from typing import Dict, Optional, Union

import torch
from mmf.common.registry import registry
from mmf.models.base_model import BaseModel
from mmf.models.interfaces.mmbt import MMBTGridHMInterface
from mmf.modules.encoders import MultiModalEncoderBase
from mmf.modules.encoders import (
Encoder,
ImageEncoder,
ImageEncoderTypes,
MultiModalEncoderBase,
ResNet152ImageEncoder,
TextEncoder,
TextEncoderTypes,
TransformerEncoder,
)
from mmf.modules.hf_layers import replace_with_jit
from mmf.utils.checkpoint import load_pretrained_model
from mmf.utils.configuration import get_mmf_cache_dir
from mmf.utils.modeling import get_optimizer_parameters_for_bert
from omegaconf import OmegaConf
from omegaconf import II, DictConfig, OmegaConf
from torch import Tensor, nn
from transformers.modeling_bert import BertForPreTraining, BertPredictionHeadTransform


# TODO: Remove after transformers package upgrade to 2.5
class MMBTConfig:
class MMBTConfigForTransformers:
"""Configuration class to store the configuration of a `MMBT Model`.
Args:
config (:obj:`~transformers.PreTrainedConfig`):
@@ -314,7 +324,7 @@ def build(self):
text_encoder, modal_encoder = encoders[0], encoders[1]
self._encoder_config = text_encoder.config

self._mmbt_config = MMBTConfig(
self._mmbt_config = MMBTConfigForTransformers(
self._encoder_config,
num_labels=self.config.num_labels,
modal_hidden_size=self.config.modal_hidden_size,
@@ -328,7 +338,10 @@ def forward(self, sample_list: Dict[str, Tensor]):
def forward(self, sample_list: Dict[str, Tensor]):

if self._is_direct_features_input:
input_modal = sample_list["image_feature_0"]
if "input_modal" in sample_list:
input_modal = sample_list["input_modal"]
else:
input_modal = sample_list["image_feature_0"]
else:
input_modal = sample_list["image"]

@@ -352,14 +365,15 @@ def forward(self, sample_list: Dict[str, Tensor]):
# If max_id is greater than 0, text occupies segment 0,
# so the modal tokens will be at segment 1.
# Otherwise, it will be zero, which it already is
if max_id == 0:
# NOTE: We compare against a tensor here for TorchScript compliance
if max_id == torch.tensor(0, dtype=max_id.dtype):
token_value = 1
else:
max_segment = self.num_max_segment - 1
# If max_id is not equal to max_segment, text segments
# start from 0 and the modal tokens take the last segment;
# otherwise, it is 0, which it already is
if max_id != max_segment:
if max_id != torch.tensor(max_segment, dtype=max_id.dtype):
token_value = max_segment
modal_token_type_ids = torch.full(
(input_modal.size(0), 1),
@@ -368,6 +382,10 @@ def forward(self, sample_list: Dict[str, Tensor]):
device=input_modal.device,
)

# In the case of XRAY inputs, the tensor may have only two dims
if input_modal.dim() == 2:
input_modal = input_modal.unsqueeze(dim=1)

# See details of inputs at
# https://github.com/huggingface/transformers/blob/1789c7/src/transformers/modeling_mmbt.py#L101 # noqa
output = self.mmbt(
@@ -474,6 +492,7 @@ def __init__(self, config, *args, **kwargs):
self.num_labels = self.config.num_labels
self.output_hidden_states = self.encoder_config.output_hidden_states
self.output_attentions = self.encoder_config.output_attentions
self.fused_feature_only = self.config.fused_feature_only

self.dropout = nn.Dropout(self.encoder_config.hidden_dropout_prob)
self.classifier = nn.Sequential(
@@ -495,6 +514,11 @@ def forward(self, sample_list: Dict[str, Tensor]):
), "output_attentions or output_hidden_states not supported in script mode"

pooled_output = self.dropout(pooled_output)

if self.fused_feature_only:
output["fused_feature"] = self.classifier[0](pooled_output)
return output

logits = self.classifier(pooled_output)
reshaped_logits = logits.contiguous().view(-1, self.num_labels)
output["scores"] = reshaped_logits
@@ -504,9 +528,41 @@ def forward(self, sample_list: Dict[str, Tensor]):

@registry.register_model("mmbt")
class MMBT(BaseModel):
def __init__(self, config):
@dataclass
class Config(BaseModel.Config):
model: str = "mmbt"
# classification or pretraining
training_head_type: str = "pretraining"
bert_model_name: str = "bert-base-uncased"
direct_features_input: bool = False
freeze_text: bool = False
freeze_modal: bool = False
freeze_complete_base: bool = False
finetune_lr_multiplier: float = 1
# Dimension of the embedding finally returned by the modal encoder
modal_hidden_size: int = 2048
text_hidden_size: int = 768
num_labels: int = 2
# This should actually be Union[ImageEncoderConfig, ImageFeatureEncoderConfig],
# but OmegaConf does not yet support Union in structured configs
modal_encoder: Encoder.Config = ImageEncoder.Config(
type=ImageEncoderTypes.resnet152, params=ResNet152ImageEncoder.Config()
)
text_encoder: Encoder.Config = TextEncoder.Config(
type=TextEncoderTypes.transformer,
params=TransformerEncoder.Config(bert_model_name=II("bert_model_name")),
)
use_modal_start_token: bool = True
use_modal_end_token: bool = True
fused_feature_only: bool = False
output_dim: int = 768

def __init__(self, config: Union[DictConfig, Config], *args, **kwargs):
super().__init__(config)

@classmethod
def from_params(cls, **kwargs):
return MMBT(OmegaConf.structured(cls.Config(**kwargs)))

def build(self):
if self.config.training_head_type == "pretraining":
self.model = MMBTForPreTraining(self.config)
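A notable detail in MMBT.Config is II("bert_model_name"), an OmegaConf interpolation that keeps the nested text-encoder params in sync with the top-level bert_model_name. A standalone sketch of the mechanism (simplified dataclasses, not the real MMF classes):

from dataclasses import dataclass, field
from omegaconf import II, OmegaConf

@dataclass
class TransformerEncoderParams:
    # Resolves against the config root at access time.
    bert_model_name: str = II("bert_model_name")

@dataclass
class ModelConfig:
    bert_model_name: str = "bert-base-uncased"
    text_encoder: TransformerEncoderParams = field(
        default_factory=TransformerEncoderParams
    )

cfg = OmegaConf.structured(ModelConfig(bert_model_name="bert-large-uncased"))
# The nested field tracks the top-level value via interpolation.
assert cfg.text_encoder.bert_model_name == "bert-large-uncased"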
121 changes: 110 additions & 11 deletions mmf/modules/encoders.py
@@ -2,6 +2,8 @@
import os
import pickle
from copy import deepcopy
from dataclasses import dataclass
from enum import Enum

import torch
import torchvision
@@ -12,16 +14,49 @@
from mmf.utils.download import download_pretrained_model
from mmf.utils.file_io import PathManager
from mmf.utils.general import get_absolute_path
from omegaconf import OmegaConf
from omegaconf import MISSING, OmegaConf
from torch import nn
from transformers.configuration_auto import AutoConfig
from transformers.modeling_auto import AutoModel


class ImageFeatureEncoder(nn.Module):
def __init__(self, config, *args, **kwargs):
# Superclass for encoder params
@dataclass
class EncoderParams:
pass


class Encoder(nn.Module):
@dataclass
class Config:
type: str = MISSING
params: EncoderParams = EncoderParams()


class ImageFeatureEncoderTypes(Enum):
default = "default"
identity = "identity"
projection = "projection"
frcnn_fc7 = "finetune_faster_rcnn_fpn_fc7"


@dataclass
class ImageFeatureEncoderBaseParams(EncoderParams):
in_dim: int = MISSING


class ImageFeatureEncoder(Encoder):
@dataclass
class Config(Encoder.Config):
type: ImageFeatureEncoderTypes = MISSING
params: EncoderParams = ImageFeatureEncoderBaseParams()

def __init__(self, config: Config, *args, **kwargs):
super().__init__()
encoder_type = config.type
if isinstance(encoder_type, ImageFeatureEncoderTypes):
encoder_type = encoder_type.value

assert (
"in_dim" in config.params
), "ImageFeatureEncoder require 'in_dim' param in config"
@@ -81,13 +116,28 @@ def forward(self, image):
return i3


class ImageEncoder(nn.Module):
def __init__(self, config, *args, **kwargs):
class ImageEncoderTypes(Enum):
default = "default"
identity = "identity"
resnet152 = "resnet152"


class ImageEncoder(Encoder):
@dataclass
class Config(Encoder.Config):
type: ImageEncoderTypes = MISSING
params: EncoderParams = EncoderParams()

def __init__(self, config: Config, *args, **kwargs):
super().__init__()
self._type = config.type

if isinstance(self._type, ImageEncoderTypes):
self._type = self._type.value

params = config.params

if self._type == "default":
if self._type == "default" or self._type == "identity":
self.module = nn.Identity()
self.module.out_dim = params.in_dim
elif self._type == "resnet152":
Expand All @@ -105,7 +155,14 @@ def forward(self, image):

# Taken from facebookresearch/mmbt with some modifications
class ResNet152ImageEncoder(nn.Module):
def __init__(self, config, *args, **kwargs):
@dataclass
class Config(EncoderParams):
pretrained: bool = True
# "avg" or "adaptive"
pool_type: str = "avg"
num_output_features: int = 1

def __init__(self, config: Config, *args, **kwargs):
super().__init__()
self.config = config
model = torchvision.models.resnet152(pretrained=config.get("pretrained", True))
@@ -140,10 +197,24 @@ def forward(self, x):
return out # BxNx2048


class TextEncoderTypes(Enum):
identity = "identity"
transformer = "transformer"
embedding = "embedding"


class TextEncoder(nn.Module):
def __init__(self, config, *args, **kwargs):
@dataclass
class Config(Encoder.Config):
# identity, transformer or embedding as of now
type: TextEncoderTypes = MISSING
params: EncoderParams = EncoderParams()

def __init__(self, config: Config, *args, **kwargs):
super().__init__()
self._type = config.type
if isinstance(self._type, TextEncoderTypes):
self._type = self._type.value

if self._type == "identity":
self.module = nn.Identity()
@@ -182,7 +253,26 @@ def forward(self, x):


class TransformerEncoder(nn.Module):
def __init__(self, config, *args, **kwargs):
@dataclass
class Config(EncoderParams):
num_segments: int = 2
bert_model_name: str = "bert-base-uncased"
# The options below can be overridden to update the BERT configuration used
# to initialize the BERT encoder. If an option is missing, or if you are
# using an encoder different from BERT, add extra parameters by inheriting
# from and extending this config. These options automatically override the
# corresponding options in your transformer encoder's configuration. For
# example, vocab_size is missing here; just add vocab_size: x to update the
# size of the vocabulary with which the encoder is initialized. If you
# update the default values, the resulting transformer will be initialized
# from scratch.
hidden_size: int = 768
num_hidden_layers: int = 12
num_attention_heads: int = 12
output_attentions: bool = False
output_hidden_states: bool = False

def __init__(self, config: Config, *args, **kwargs):
super().__init__()
self.config = config
self.module = AutoModel.from_pretrained(
@@ -209,7 +299,7 @@ def _init_segment_embeddings(self):
)
self.embeddings.token_type_embeddings = new_embeds

def _build_encoder_config(self, config):
def _build_encoder_config(self, config: Config):
return AutoConfig.from_pretrained(
self.config.bert_model_name, **OmegaConf.to_container(self.config)
)
@@ -222,7 +312,16 @@ def forward(self, *args, **kwargs):
class MultiModalEncoderBase(nn.Module):
__jit_unused_properties__ = ["encoder_config"]

def __init__(self, config, *args, **kwargs):
@dataclass
class Config(EncoderParams):
# This should actually be Union[ImageEncoderConfig, ImageFeatureEncoderConfig],
# but OmegaConf does not yet support Union in structured configs
modal_encoder: Encoder.Config = ResNet152ImageEncoder.Config()
text_encoder: TextEncoder.Config = TransformerEncoder.Config()
direct_features_input: bool = False
modal_hidden_size: int = 2048
text_hidden_size: int = 768

def __init__(self, config: Config, *args, **kwargs):
super().__init__()
self.config = config

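The encoders now share a common pattern: an Enum-valued type selector plus a nested params dataclass, validated by OmegaConf at merge time. A condensed sketch of the pattern with illustrative names (not the exact MMF classes):

from dataclasses import dataclass, field
from enum import Enum
from omegaconf import MISSING, OmegaConf

class ImageEncoderTypes(Enum):
    identity = "identity"
    resnet152 = "resnet152"

@dataclass
class ResNet152Params:
    pretrained: bool = True
    pool_type: str = "avg"   # "avg" or "adaptive"
    num_output_features: int = 1

@dataclass
class ImageEncoderConfig:
    type: ImageEncoderTypes = MISSING
    params: ResNet152Params = field(default_factory=ResNet152Params)

cfg = OmegaConf.structured(ImageEncoderConfig(type=ImageEncoderTypes.resnet152))
# OmegaConf validates enum-typed fields, so cfg.type = "resnet15" would raise.
assert cfg.type == ImageEncoderTypes.resnet152
assert cfg.params.pool_type == "avg"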