feature(pu): add three variants of Bilinear classes and a FiLM class #703

Merged 8 commits on Aug 15, 2023
180 changes: 167 additions & 13 deletions ding/torch_utils/network/merge.py
@@ -1,22 +1,176 @@
"""
This file provides two components for consolidating data streams, SumMerge and VectorMerge.

The following components can be used when we are dealing with data from multiple modes,
This file provides an implementation of several different neural network modules that are used for merging and
transforming input data in various ways. The following components can be used when we are dealing with data from multiple modes,
or when we need to merge multiple intermediate embedded representations in the forward process of a model.

While SumMerge simply sums multiple data streams in the first dimension,
VectorMerge provides three more complex weighted summations.
The main classes defined in this code are:

- BilinearGeneral: This class implements a bilinear transformation layer that applies a bilinear transformation to incoming data,
as described in the "Multiplicative Interactions and Where to Find Them", published at ICLR 2020.
puyuan1996 marked this conversation as resolved.
Show resolved Hide resolved
The transformation involves two input features and an output feature, and also includes an optional bias term.

- TorchBilinearCustomized: This class implements a bilinear layer similar to the one provided by PyTorch (torch.nn.Bilinear),
but with additional customizations. This class can be used as an alternative to the BilinearGeneral class.

- TorchBilinear: This class is a simple wrapper around the PyTorch's built-in nn.Bilinear module. It provides the same
functionality as PyTorch's nn.Bilinear but within the structure of the current module.

- FiLM: This class implements a Feature-wise Linear Modulation (FiLM) layer. FiLM layers apply an affine transformation
to the input data, conditioned on some additional context information.

- GatingType: This is an enumeration class that defines different types of gating mechanisms that can be used in the modules.

- SumMerge: This class provides a simple summing mechanism to merge input streams.

- VectorMerge: This class implements a more complex merging mechanism for vector streams.
The streams are first transformed using layer normalization, a ReLU activation, and a linear layer.
Then they are merged either by simple summing or by using a gating mechanism.

The implementation of these classes involves PyTorch and Numpy libraries, and the classes use PyTorch's nn.Module as the base class,
making them compatible with PyTorch's neural network modules and functionalities.
These modules can be useful building blocks in more complex deep learning architectures.
"""

import enum
import math
from collections import OrderedDict
from typing import Dict, List

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class BilinearGeneral(nn.Module):
"""
Overview:
Bilinear implementation as in:
Multiplicative Interactions and Where to Find Them, ICLR 2020, https://openreview.net/forum?id=rylnK6VtDH
Arguments:
- in1_features (:obj:`int`): size of each first input sample
- in2_features (:obj:`int`): size of each second input sample
- out_features (:obj:`int`): size of each output sample
"""

def __init__(self, in1_features, in2_features, out_features):
super(BilinearGeneral, self).__init__()
# Initialize the weight matrices W and U, and the bias vectors V and b
self.W = nn.Parameter(torch.Tensor(out_features, in1_features, in2_features))
self.U = nn.Parameter(torch.Tensor(out_features, in2_features))
self.V = nn.Parameter(torch.Tensor(out_features, in1_features))
self.b = nn.Parameter(torch.Tensor(out_features))
self.in1_features = in1_features
self.in2_features = in2_features
self.out_features = out_features
self.reset_parameters()

def reset_parameters(self):
stdv = 1. / np.sqrt(self.in1_features)
self.W.data.uniform_(-stdv, stdv)
self.U.data.uniform_(-stdv, stdv)
self.V.data.uniform_(-stdv, stdv)
self.b.data.uniform_(-stdv, stdv)

def forward(self, x, z):
        # Bilinear term: x^T W z
        out_W = torch.einsum('bi,kij,bj->bk', x, self.W, z)
        # Linear term in z: U z
        out_U = z.matmul(self.U.t())
        # Linear term in x: V x
        out_V = x.matmul(self.V.t())
        # y = x^T W z + U z + V x + b
out = out_W + out_U + out_V + self.b
return out
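
As a quick illustration (not part of this diff), BilinearGeneral computes y = x^T W z + U z + V x + b for each batch element; a minimal usage sketch:

    layer = BilinearGeneral(in1_features=4, in2_features=5, out_features=3)
    x = torch.randn(2, 4)  # batch of first inputs
    z = torch.randn(2, 5)  # batch of second inputs
    y = layer(x, z)  # shape: (2, 3)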


class TorchBilinearCustomized(nn.Module):
"""
Overview:
Customized Torch Bilinear implementation.
Arguments:
- in1_features (:obj:`int`): size of each first input sample
- in2_features (:obj:`int`): size of each second input sample
- out_features (:obj:`int`): size of each output sample
"""

def __init__(self, in1_features, in2_features, out_features):
super(TorchBilinearCustomized, self).__init__()
self.in1_features = in1_features
self.in2_features = in2_features
self.out_features = out_features
self.weight = nn.Parameter(torch.Tensor(out_features, in1_features, in2_features))
self.bias = nn.Parameter(torch.Tensor(out_features))
self.reset_parameters()

def reset_parameters(self):
bound = 1 / math.sqrt(self.in1_features)
nn.init.uniform_(self.weight, -bound, bound)
nn.init.uniform_(self.bias, -bound, bound)

def forward(self, x, z):
        # Bilinear form via einsum: out[b, o] = sum_{i,j} x[b, i] * weight[o, i, j] * z[b, j]
        out = torch.einsum('bi,oij,bj->bo', x, self.weight, z) + self.bias
        # squeeze(-1) is a no-op unless out_features == 1
        return out.squeeze(-1)
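
For reference, the einsum above reproduces the bilinear form of torch.nn.Bilinear when the weights match; a minimal sketch of the expected equivalence (the consistency test added below checks this numerically):

    custom = TorchBilinearCustomized(4, 5, 3)
    ref = nn.Bilinear(4, 5, 3)
    ref.weight.data.copy_(custom.weight.data)
    ref.bias.data.copy_(custom.bias.data)
    x, z = torch.randn(2, 4), torch.randn(2, 5)
    assert torch.allclose(custom(x, z), ref(x, z), atol=1e-6)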


class TorchBilinear(nn.Bilinear):
"""
Overview:
Implementation of the Bilinear layer as in PyTorch:
https://pytorch.org/docs/stable/generated/torch.nn.Bilinear.html#torch.nn.Bilinear
Arguments:
- in1_features (:obj:`int`): size of each first input sample
- in2_features (:obj:`int`): size of each second input sample
- out_features (:obj:`int`): size of each output sample
- bias (:obj:`bool`): If set to False, the layer will not learn an additive bias. Default: ``True``.
"""

def __init__(self, in1_features, in2_features, out_features, bias=True):
super(TorchBilinear, self).__init__(in1_features, in2_features, out_features, bias)

def forward(self, x, z):
return super(TorchBilinear, self).forward(x, z)


class FiLM(nn.Module):
"""
Overview:
Feature-wise Linear Modulation (FiLM) Layer.
This layer applies feature-wise affine transformation based on context.
Arguments:
    - feature_dim (:obj:`int`): The dimension of the input feature vector.
    - context_dim (:obj:`int`): The dimension of the input context vector.
"""

def __init__(self, feature_dim, context_dim):
super(FiLM, self).__init__()
# Define the fully connected layer for context
# The output dimension is twice the feature dimension for gamma and beta
self.context_layer = nn.Linear(context_dim, 2 * feature_dim)

def forward(self, feature, context):
"""
Overview:
Forward propagation.
Arguments:
            - feature (:obj:`torch.Tensor`): The input feature, shape (batch_size, feature_dim).
            - context (:obj:`torch.Tensor`): The input context, shape (batch_size, context_dim).
        Returns:
            - conditioned_feature (:obj:`torch.Tensor`): The output feature after FiLM, shape (batch_size, feature_dim).
"""
# Pass context through the fully connected layer
out = self.context_layer(context)
        # Split the output into two chunks of size feature_dim along dim 1:
        # gamma (scale) and beta (shift)
gamma, beta = torch.split(out, out.shape[1] // 2, dim=1)
# Apply feature-wise affine transformation
conditioned_feature = gamma * feature + beta
return conditioned_feature
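
As a quick illustration (not part of this diff), FiLM predicts a scale gamma and a shift beta from the context and applies them feature-wise; a minimal usage sketch:

    film = FiLM(feature_dim=8, context_dim=16)
    feature = torch.randn(4, 8)
    context = torch.randn(4, 16)
    conditioned = film(feature, context)  # shape: (4, 8)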


class GatingType(enum.Enum):
r"""
Overview:
@@ -67,11 +221,11 @@ class VectorMerge(nn.Module):
"""

    def __init__(
            self,
            input_sizes: Dict[str, int],
            output_size: int,
            gating_type: GatingType = GatingType.NONE,
            use_layer_norm: bool = True,
    ):
r"""
Overview:
@@ -144,8 +298,8 @@ def encode(self, inputs: Dict[str, Tensor]):
return gates, outputs

    def _compute_gate(
            self,
            init_gate: List[Tensor],
    ):
if len(self._input_sizes) == 2:
gate = [self._gating_linears[name](y) for name, y in zip(self._input_sizes.keys(), init_gate)]
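
A minimal usage sketch for VectorMerge, assuming the constructor signature shown above and a forward pass that takes a dict of named streams (illustrative only):

    merge = VectorMerge({'obs': 16, 'action': 8}, output_size=32, gating_type=GatingType.NONE)
    streams = {'obs': torch.randn(2, 16), 'action': torch.randn(2, 8)}
    merged = merge(streams)  # shape: (2, 32)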
129 changes: 129 additions & 0 deletions ding/torch_utils/tests/test_merge.py
@@ -0,0 +1,129 @@
import pytest
import torch
from ding.torch_utils.network.merge import TorchBilinearCustomized, TorchBilinear, BilinearGeneral, FiLM


@pytest.mark.unittest
def test_torch_bilinear_customized():
batch_size = 10
in1_features = 20
in2_features = 30
out_features = 40
bilinear_customized = TorchBilinearCustomized(in1_features, in2_features, out_features)
x = torch.randn(batch_size, in1_features)
z = torch.randn(batch_size, in2_features)
out = bilinear_customized(x, z)
assert out.shape == (batch_size, out_features), "Output shape does not match expected shape."


@pytest.mark.unittest
def test_torch_bilinear():
batch_size = 10
in1_features = 20
in2_features = 30
out_features = 40
torch_bilinear = TorchBilinear(in1_features, in2_features, out_features)
x = torch.randn(batch_size, in1_features)
z = torch.randn(batch_size, in2_features)
out = torch_bilinear(x, z)
assert out.shape == (batch_size, out_features), "Output shape does not match expected shape."


@pytest.mark.unittest
def test_bilinear_consistency():
batch_size = 10
in1_features = 20
in2_features = 30
out_features = 40

# Initialize weights and biases with set values
weight = torch.randn(out_features, in1_features, in2_features)
bias = torch.randn(out_features)

# Create and initialize TorchBilinearCustomized and TorchBilinear models
bilinear_customized = TorchBilinearCustomized(in1_features, in2_features, out_features)
bilinear_customized.weight.data = weight.clone()
bilinear_customized.bias.data = bias.clone()

torch_bilinear = TorchBilinear(in1_features, in2_features, out_features)
torch_bilinear.weight.data = weight.clone()
torch_bilinear.bias.data = bias.clone()

# Provide same input to both models
x = torch.randn(batch_size, in1_features)
z = torch.randn(batch_size, in2_features)

# Compute outputs
out_bilinear_customized = bilinear_customized(x, z)
out_torch_bilinear = torch_bilinear(x, z)

# Compute the mean squared error between outputs
mse = torch.mean((out_bilinear_customized - out_torch_bilinear) ** 2)

print(f"Mean Squared Error between outputs: {mse.item()}")

    # The two layers compute the same bilinear form, so the outputs should agree up to
    # floating-point error; the strict check is left disabled.
    # assert torch.allclose(out_bilinear_customized, out_torch_bilinear), "Outputs of TorchBilinearCustomized and TorchBilinear are not the same."


@pytest.mark.unittest
def test_bilinear_general():
"""
Overview:
Test for the `BilinearGeneral` class.
"""
# Define the input dimensions and batch size
in1_features = 20
in2_features = 30
out_features = 40
batch_size = 10

# Create a BilinearGeneral instance
bilinear_general = BilinearGeneral(in1_features, in2_features, out_features)

# Create random inputs
input1 = torch.randn(batch_size, in1_features)
input2 = torch.randn(batch_size, in2_features)

# Perform forward pass
output = bilinear_general(input1, input2)

# Check output shape
assert output.shape == (batch_size, out_features), "Output shape does not match expected shape."

# Check parameter shapes
assert bilinear_general.W.shape == (
out_features, in1_features, in2_features), "Weight W shape does not match expected shape."
assert bilinear_general.U.shape == (out_features, in2_features), "Weight U shape does not match expected shape."
assert bilinear_general.V.shape == (out_features, in1_features), "Weight V shape does not match expected shape."
assert bilinear_general.b.shape == (out_features,), "Bias shape does not match expected shape."

# Check parameter types
assert isinstance(bilinear_general.W, torch.nn.Parameter), "Weight W is not an instance of torch.nn.Parameter."
assert isinstance(bilinear_general.U, torch.nn.Parameter), "Weight U is not an instance of torch.nn.Parameter."
assert isinstance(bilinear_general.V, torch.nn.Parameter), "Weight V is not an instance of torch.nn.Parameter."
assert isinstance(bilinear_general.b, torch.nn.Parameter), "Bias is not an instance of torch.nn.Parameter."


@pytest.mark.unittest
def test_film_forward():
# Set the feature and context dimensions
feature_dim = 128
context_dim = 256

# Initialize the FiLM layer
film_layer = FiLM(feature_dim, context_dim)

# Create random feature and context vectors
feature = torch.randn((32, feature_dim)) # batch size is 32
context = torch.randn((32, context_dim)) # batch size is 32

# Forward propagation
conditioned_feature = film_layer(feature, context)

# Check the output shape
assert conditioned_feature.shape == feature.shape, \
f'Expected output shape {feature.shape}, but got {conditioned_feature.shape}'

# Check that the output is different from the input
assert not torch.all(torch.eq(feature, conditioned_feature)), \
'The output feature is the same as the input feature'