Skip to content

Commit

Permalink
[feat] Add init dataclasses for mmbt and encoders
Browse files Browse the repository at this point in the history
Summary:
As step one of FAIM integration, we allow building our models from config so that the models are purely decoupled from configuration and users know what the model expects.

We first do this for the MMBT model.

- Adds configs for MMBT and respective encoders.
- Also adds from_params method for MMBT

There is a known OmegaConf limitation that prevents using Union types in structured configs; see omry/omegaconf#144.

Differential Revision: D23699688

fbshipit-source-id: 0cd18791e2cbce454f40088dca4b35443b62b567
  • Loading branch information
apsdehal authored and facebook-github-bot committed Sep 16, 2020
1 parent 044f575 commit f5bb1d3
Show file tree
Hide file tree
Showing 3 changed files with 217 additions and 17 deletions.
47 changes: 41 additions & 6 deletions mmf/models/mmbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,32 @@

import os
from copy import deepcopy
from typing import Optional
from dataclasses import dataclass
from typing import Any, Optional

import torch
from mmf.common.registry import registry
from mmf.models.base_model import BaseModel
from mmf.models.interfaces.mmbt import MMBTGridHMInterface
from mmf.modules.encoders import MultiModalEncoderBase
from mmf.modules.encoders import (
EncoderConfig,
MultiModalEncoderBase,
ResNet152ImageEncoderConfig,
TextEncoderConfig,
TransformerEncoderConfig,
TransformerEncoderParams,
)
from mmf.modules.hf_layers import replace_with_jit
from mmf.utils.checkpoint import load_pretrained_model
from mmf.utils.configuration import get_mmf_cache_dir
from mmf.utils.modeling import get_optimizer_parameters_for_bert
from omegaconf import OmegaConf
from omegaconf import II, OmegaConf
from torch import Tensor, nn
from transformers.modeling_bert import BertForPreTraining, BertPredictionHeadTransform


# TODO: Remove after transformers package upgrade to 2.5
class MMBTConfig:
class MMBTConfigForTransformers:
"""Configuration class to store the configuration of a `MMBT Model`.
Args:
config (:obj:`~transformers.PreTrainedConfig`):
Expand Down Expand Up @@ -312,7 +320,7 @@ def build(self):
text_encoder, modal_encoder = encoders[0], encoders[1]
self._encoder_config = text_encoder.config

self._mmbt_config = MMBTConfig(
self._mmbt_config = MMBTConfigForTransformers(
self._encoder_config,
num_labels=self.config.num_labels,
modal_hidden_size=self.config.modal_hidden_size,
Expand Down Expand Up @@ -519,13 +527,40 @@ def forward(
return output


@dataclass
class MMBTConfig:
    """Structured config for the MMBT model.

    Pairs a text encoder with a modal (image / image-feature) encoder; used by
    ``MMBT.from_params`` to build the model purely from configuration.
    """

    # "classification" or "pretraining"
    training_head_type: str = "pretraining"
    # HuggingFace identifier of the pretrained text model. Must match a real
    # hub id ("bert-base-uncased", hyphenated), since it is interpolated into
    # text_encoder.params.bert_model_name below and loaded from the hub.
    bert_model_name: str = "bert-base-uncased"
    direct_features_input: bool = False
    freeze_text: bool = False
    freeze_modal: bool = False
    freeze_complete_base: bool = False
    finetune_lr_multiplier: float = 1
    # Dimension of the embedding finally returned by the modal encoder
    modal_hidden_size: int = 2048
    text_hidden_size: int = 768
    num_labels: int = 2
    # This actually is Union[ImageEncoderConfig, ImageFeatureEncoderConfig];
    # OmegaConf structured configs do not support Union yet (omegaconf#144).
    modal_encoder: EncoderConfig = ResNet152ImageEncoderConfig()
    # II("bert_model_name") interpolates the top-level field above so both
    # stay in sync.
    # NOTE(review): dataclass-instance defaults are shared across instances
    # (and rejected by dataclasses on py3.11+); consider field(default_factory=...)
    # once `field` is imported in this module.
    text_encoder: TextEncoderConfig = TransformerEncoderConfig(
        params=TransformerEncoderParams(bert_model_name=II("bert_model_name"))
    )
    use_modal_start_token: bool = True
    use_modal_end_token: bool = True


@registry.register_model("mmbt")
class MMBT(BaseModel):
def __init__(self, config):
def __init__(self, config: MMBTConfig):
super().__init__(config)
# Replace transformer layers with scriptable JIT layers
replace_with_jit()

@classmethod
def from_params(cls, **kwargs):
return MMBT(OmegaConf.structured(MMBTConfig(**kwargs)))

def build(self):
if self.config.training_head_type == "pretraining":
self.model = MMBTForPreTraining(self.config)
Expand Down
135 changes: 126 additions & 9 deletions mmf/modules/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
import os
import pickle
from copy import deepcopy
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict

import torch
import torchvision
Expand All @@ -12,14 +15,44 @@
from mmf.utils.download import download_pretrained_model
from mmf.utils.file_io import PathManager
from mmf.utils.general import get_absolute_path
from omegaconf import OmegaConf
from omegaconf import MISSING, OmegaConf
from torch import nn
from transformers.configuration_auto import AutoConfig
from transformers.modeling_auto import AutoModel


class ImageFeatureEncoderTypes(Enum):
    """Valid ``type`` values for image-feature encoders; the string values
    are the identifiers used in YAML/structured configs."""

    default = "default"
    identity = "identity"
    projection = "projection"
    frcnn_fc7 = "finetune_faster_rcnn_fpn_fc7"


# SuperClass for encoder params
@dataclass
class EncoderParams:
    """Empty base class for encoder parameter dataclasses; concrete encoders
    subclass this and add their own fields."""

    pass


@dataclass
class EncoderConfig:
    """Base structured config shared by all encoders.

    ``type`` selects the encoder implementation, ``params`` carries its
    encoder-specific parameters.
    """

    type: str = MISSING
    # default_factory so each EncoderConfig gets its own EncoderParams
    # instead of sharing one mutable instance (also required by
    # dataclasses on py3.11+).
    params: EncoderParams = field(default_factory=EncoderParams)


@dataclass
class ImageFeatureEncoderBaseParams(EncoderParams):
    """Parameters common to encoders operating on pre-extracted image features."""

    # Dimensionality of the incoming feature vectors; mandatory (MISSING).
    in_dim: int = MISSING


@dataclass
class ImageFeatureEncoderConfig(EncoderConfig):
    """Config for encoders that operate on pre-extracted image features."""

    type: ImageFeatureEncoderTypes = MISSING
    # Annotated with the base EncoderParams so subclasses may substitute their
    # own params; default_factory avoids one mutable default being shared
    # across every config instance (and is required on py3.11+).
    params: EncoderParams = field(default_factory=ImageFeatureEncoderBaseParams)


class ImageFeatureEncoder(nn.Module):
def __init__(self, config, *args, **kwargs):
def __init__(self, config: ImageFeatureEncoderConfig, *args, **kwargs):
super().__init__()
encoder_type = config.type
assert (
Expand Down Expand Up @@ -81,13 +114,25 @@ def forward(self, image):
return i3


class ImageEncoderTypes(Enum):
    """Valid ``type`` values for raw-image encoders; the string values are
    the identifiers used in YAML/structured configs."""

    default = "default"
    identity = "identity"
    resnet152 = "resnet152"


@dataclass
class ImageEncoderConfig(EncoderConfig):
    """Config for encoders that operate on raw images.

    Fix: ``type`` was annotated ``ImageFeatureEncoderTypes`` (the feature
    encoder enum); raw-image encoders select from ``ImageEncoderTypes``
    (cf. ResNet152ImageEncoderConfig, which assigns
    ``ImageEncoderTypes.resnet152``).
    """

    type: ImageEncoderTypes = MISSING
    # default_factory avoids a shared mutable default (py3.11+ requirement).
    params: EncoderParams = field(default_factory=EncoderParams)


class ImageEncoder(nn.Module):
def __init__(self, config, *args, **kwargs):
def __init__(self, config: ImageEncoderConfig, *args, **kwargs):
super().__init__()
self._type = config.type
params = config.params

if self._type == "default":
if self._type == "default" or self._type == "identity":
self.module = nn.Identity()
self.module.out_dim = params.in_dim
elif self._type == "resnet152":
Expand All @@ -103,9 +148,27 @@ def forward(self, image):
return self.module(image)


@dataclass
class ResNet152ImageEncoderParams(EncoderParams):
    """Parameters for ResNet152ImageEncoder (torchvision resnet152 backbone)."""

    # Load torchvision's ImageNet-pretrained weights when True.
    pretrained: bool = True
    # "avg" or "adaptive"
    pool_type: str = "avg"
    # Number of pooled feature vectors produced per image (encoder output is
    # BxNx2048 per the forward's comment).
    num_output_features: int = 1


@dataclass
class ResNet152ImageEncoderConfig(ImageEncoderConfig):
    """Use this config when initializing ResNet152ImageEncoder from ImageEncoder
    otherwise use ResNet152ImageEncoderParams to initialize it directly.
    """

    type: ImageEncoderTypes = ImageEncoderTypes.resnet152
    # default_factory so config instances do not share one mutable params
    # object (also required by dataclasses on py3.11+).
    params: ResNet152ImageEncoderParams = field(
        default_factory=ResNet152ImageEncoderParams
    )


# Taken from facebookresearch/mmbt with some modifications
class ResNet152ImageEncoder(nn.Module):
def __init__(self, config, *args, **kwargs):
def __init__(self, config: ResNet152ImageEncoderParams, *args, **kwargs):
super().__init__()
self.config = config
model = torchvision.models.resnet152(pretrained=config.get("pretrained", True))
Expand Down Expand Up @@ -140,8 +203,21 @@ def forward(self, x):
return out # BxNx2048


class TextEncoderTypes(Enum):
    """Valid ``type`` values for text encoders; the string values are the
    identifiers used in YAML/structured configs."""

    identity = "identity"
    transformer = "transformer"
    embedding = "embedding"


@dataclass
class TextEncoderConfig(EncoderConfig):
    """Config for text encoders."""

    # identity, transformer or embedding as of now
    type: TextEncoderTypes = MISSING
    # default_factory avoids a shared mutable default (py3.11+ requirement).
    params: EncoderParams = field(default_factory=EncoderParams)


class TextEncoder(nn.Module):
def __init__(self, config, *args, **kwargs):
def __init__(self, config: TextEncoderConfig, *args, **kwargs):
super().__init__()
self._type = config.type

Expand Down Expand Up @@ -181,8 +257,39 @@ def forward(self, x):
return x.squeeze()


@dataclass
class TransformerEncoderParams(EncoderParams):
    """Parameters for TransformerEncoder, which wraps a HuggingFace
    AutoModel/AutoConfig (BERT by default)."""

    # Number of token-type (segment) embeddings.
    num_segments: int = 2
    # HuggingFace identifier of the pretrained model/config to load.
    bert_model_name: str = "bert-base-uncased"
    # The options below override the corresponding entries of the BERT
    # configuration used to initialize the encoder. If an option is missing
    # here, or you are using an encoder different from BERT, add extra
    # parameters by inheriting and extending this config; those options will
    # automatically override your transformer encoder's configuration. E.g.
    # vocab_size is missing here — just add ``vocab_size: x`` to set the
    # vocabulary size with which the encoder is initialized. If you update
    # the default values, the transformer you get will be initialized from
    # scratch.
    hidden_size: int = 768
    num_hidden_layers: int = 12
    num_attention_heads: int = 12
    output_attentions: bool = False
    output_hidden_states: bool = False


@dataclass
class TransformerEncoderConfig(TextEncoderConfig):
    """Use this class to initialize config when initializing the TextEncoder
    directly to get access to TransformerEncoder. Otherwise initialize using
    TransformerEncoderParams
    """

    type: TextEncoderTypes = TextEncoderTypes.transformer
    # default_factory so config instances do not share one mutable params
    # object (also required by dataclasses on py3.11+).
    params: TransformerEncoderParams = field(default_factory=TransformerEncoderParams)


class TransformerEncoder(nn.Module):
def __init__(self, config, *args, **kwargs):
def __init__(self, config: TransformerEncoderParams, *args, **kwargs):
super().__init__()
self.config = config
self.module = AutoModel.from_pretrained(
Expand All @@ -209,7 +316,7 @@ def _init_segment_embeddings(self):
)
self.embeddings.token_type_embeddings = new_embeds

def _build_encoder_config(self, config):
def _build_encoder_config(self, config: TransformerEncoderParams):
return AutoConfig.from_pretrained(
self.config.bert_model_name, **OmegaConf.to_container(self.config)
)
Expand All @@ -219,8 +326,18 @@ def forward(self, *args, **kwargs):
return self.module(*args, **kwargs)[1]


@dataclass
class MultiModalEncoderBaseConfig:
    """Structured config for MultiModalEncoderBase: a text encoder plus a
    modal (image / image-feature) encoder."""

    # This actually is Union[ImageEncoderConfig, ImageFeatureEncoderConfig];
    # OmegaConf structured configs do not support Union yet.
    # default_factory avoids sharing mutable encoder-config instances across
    # every MultiModalEncoderBaseConfig (also required on py3.11+).
    modal_encoder: EncoderConfig = field(default_factory=ResNet152ImageEncoderConfig)
    text_encoder: TextEncoderConfig = field(default_factory=TransformerEncoderConfig)
    direct_features_input: bool = False
    modal_hidden_size: int = 2048
    text_hidden_size: int = 768


class MultiModalEncoderBase(nn.Module):
def __init__(self, config, *args, **kwargs):
def __init__(self, config: MultiModalEncoderBaseConfig, *args, **kwargs):
super().__init__()
self.config = config

Expand Down
52 changes: 50 additions & 2 deletions tests/models/test_mmbt_script.py → tests/models/test_mmbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,18 @@
import io
import unittest

import tests.test_utils as test_utils
import torch
from mmf.common.registry import registry
from mmf.models.mmbt import MMBT, MMBTConfig
from mmf.modules.encoders import (
ResNet152ImageEncoderConfig,
ResNet152ImageEncoderParams,
TextEncoderConfig,
)
from mmf.utils.configuration import Configuration
from mmf.utils.env import setup_imports

import tests.test_utils as test_utils
from omegaconf import OmegaConf


class TestMMBTTorchscript(unittest.TestCase):
Expand Down Expand Up @@ -60,3 +66,45 @@ def test_finetune_model(self):
)

self.assertTrue(torch.equal(model_output["scores"], script_output["scores"]))


class TestMMBTConfig(unittest.TestCase):
    """Tests for building MMBT from structured configs / from_params."""

    @staticmethod
    def _build_expected_config():
        """Structured MMBTConfig shared by the tests below: untrained
        ResNet152 modal encoder + identity text encoder."""
        return OmegaConf.structured(
            MMBTConfig(
                modal_encoder=ResNet152ImageEncoderConfig(
                    params=ResNet152ImageEncoderParams(pretrained=False)
                ),
                text_encoder=TextEncoderConfig(type="identity"),
            )
        )

    def test_mmbt_from_params(self):
        # default init
        mmbt = MMBT.from_params(
            modal_encoder=ResNet152ImageEncoderConfig(
                params=ResNet152ImageEncoderParams(pretrained=False)
            ),
            text_encoder=TextEncoderConfig(type="identity"),
        )

        config = self._build_expected_config()
        self.assertIsNotNone(mmbt)
        # Make sure that the config is created from MMBTConfig
        self.assertEqual(mmbt.config, config)

    @test_utils.skip_if_no_network
    def test_mmbt_pretrained(self):
        mmbt = MMBT.from_params()
        self.assertIsNotNone(mmbt)

    def test_mmbt_directly_from_config(self):
        config = self._build_expected_config()
        mmbt = MMBT(config)
        self.assertIsNotNone(mmbt)
        # Make sure that the config is created from MMBTConfig
        self.assertEqual(mmbt.config, config)

0 comments on commit f5bb1d3

Please sign in to comment.