feature(cy): add tensor stream merging tools #673

Merged (4 commits) on Jun 10, 2023
1 change: 1 addition & 0 deletions ding/torch_utils/network/__init__.py
@@ -11,3 +11,4 @@
from .gumbel_softmax import GumbelSoftmax
from .gtrxl import GTrXL, GRUGatingUnit
from .popart import PopArt
from .merge import GatingType, SumMerge, VectorMerge
175 changes: 175 additions & 0 deletions ding/torch_utils/network/merge.py
@@ -0,0 +1,175 @@
"""
This file provides two components for consolidating data streams, SumMerge and VectorMerge.

The following components can be used when we are dealing with data from multiple modes,
or when we need to merge multiple intermediate embedded representations in the forward process of a model.

While SumMerge simply sums multiple data streams in the first dimension,
VectorMerge provides three more complex weighted summations.
"""

import enum
from typing import List, Dict, Tuple
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class GatingType(enum.Enum):
r"""
Overview:
Defines how the tensors are gated and aggregated in modules.
"""
NONE = 'none'
GLOBAL = 'global'
POINTWISE = 'pointwise'


class SumMerge(nn.Module):
r"""
Overview:
Merge streams using a simple sum.
Streams must have the same size.
This module can merge any type of stream (vector, units or visual).
"""

    def forward(self, tensors: List[Tensor]) -> Tensor:
        # stack the input tensors along a new leading dimension
        stacked = torch.stack(tensors, dim=0)
        # sum over that leading dimension, i.e. an element-wise sum of the streams
        summed = torch.sum(stacked, dim=0)
        return summed
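
        # For example (illustrative shapes): SumMerge()([a, b, c]) equals a + b + c
        # for tensors a, b, c of identical shape, e.g. three tensors of shape (2, 3).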


class VectorMerge(nn.Module):
r"""
Overview:
Merge vector streams.

Streams are first transformed through layer normalization, relu and linear
layers, then summed, so they don't need to have the same size.
Gating can also be used before the sum.

If gating_type is not none, the sum is weighted using a softmax
of the intermediate activations labelled above.

Sepcifically,
GatingType.NONE:
Means simple addition of streams and the sum is not weighted based on gate features.
GatingType.GLOBAL:
Each data stream is weighted by a global gate value, and the sum of all global gate values is 1.
GatingType.POINTWISE:
Compared to GLOBAL, each value in the data stream feature tensor is weighted.
"""

def __init__(
self,
input_sizes: Dict[str, int],
output_size: int,
gating_type: GatingType = GatingType.NONE,
use_layer_norm: bool = True,
):
r"""
Overview:
Initializes VectorMerge module.
Arguments:
- input_sizes: A dictionary mapping input names to their size (a single
integer for 1d inputs, or None for 0d inputs).
If an input size is None, we assume it's ().
- output_size: The size of the output vector.
- gating_type: The type of gating mechanism to use.
- use_layer_norm: Whether to use layer normalization.
"""
super().__init__()
self._input_sizes = OrderedDict(input_sizes)
self._output_size = output_size
self._gating_type = gating_type
self._use_layer_norm = use_layer_norm

if self._use_layer_norm:
self._layer_norms = nn.ModuleDict()
else:
self._layer_norms = None

self._linears = nn.ModuleDict()
for name, size in self._input_sizes.items():
linear_input_size = size if size > 0 else 1
if self._use_layer_norm:
self._layer_norms[name] = nn.LayerNorm(linear_input_size)
self._linears[name] = nn.Linear(linear_input_size, self._output_size)

self._gating_linears = nn.ModuleDict()
if self._gating_type is GatingType.GLOBAL:
self.gate_size = 1
elif self._gating_type is GatingType.POINTWISE:
self.gate_size = self._output_size
elif self._gating_type is GatingType.NONE:
self._gating_linears = None
else:
raise ValueError(f'Gating type {self._gating_type} is not supported')

if self._gating_linears is not None:
if len(self._input_sizes) == 2:
# more efficient than the general version below
for name, size in self._input_sizes.items():
gate_input_size = size if size > 0 else 1
gating_layer = nn.Linear(gate_input_size, self.gate_size)
torch.nn.init.normal_(gating_layer.weight, std=0.005)
torch.nn.init.constant_(gating_layer.bias, 0.0)
self._gating_linears[name] = gating_layer
else:
for name, size in self._input_sizes.items():
gate_input_size = size if size > 0 else 1
gating_layer = nn.Linear(gate_input_size, len(self._input_sizes) * self.gate_size)
torch.nn.init.normal_(gating_layer.weight, std=0.005)
torch.nn.init.constant_(gating_layer.bias, 0.0)
self._gating_linears[name] = gating_layer

    def encode(self, inputs: Dict[str, Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
gates, outputs = [], []
for name, size in self._input_sizes.items():
feature = inputs[name]
if size <= 0 and feature.dim() == 1:
feature = feature.unsqueeze(-1)
feature = feature.to(torch.float32)
if self._use_layer_norm and name in self._layer_norms:
feature = self._layer_norms[name](feature)
feature = F.relu(feature)
gates.append(feature)
outputs.append(self._linears[name](feature))
return gates, outputs

    def _compute_gate(
            self,
            init_gate: List[Tensor],
    ) -> List[Tensor]:
if len(self._input_sizes) == 2:
gate = [self._gating_linears[name](y) for name, y in zip(self._input_sizes.keys(), init_gate)]
gate = sum(gate)
sigmoid = torch.sigmoid(gate)
gate = [sigmoid, 1.0 - sigmoid]
else:
gate = [self._gating_linears[name](y) for name, y in zip(self._input_sizes.keys(), init_gate)]
gate = sum(gate)
gate = gate.reshape([-1, len(self._input_sizes), self.gate_size])
gate = F.softmax(gate, dim=1)
assert gate.shape[1] == len(self._input_sizes)
gate = [gate[:, i] for i in range(len(self._input_sizes))]
return gate
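
        # For example, with two streams and GatingType.GLOBAL: if the summed gate logit for a
        # batch element is z, the gates returned above are [sigmoid(z), 1 - sigmoid(z)], so the
        # per-stream weights always sum to 1; the softmax branch generalizes this to N streams.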

def forward(self, inputs: Dict[str, Tensor]) -> Tensor:
gates, outputs = self.encode(inputs)
        if len(outputs) == 1:
            # Special case: a single input stream does not need any gating.
            output = outputs[0]
elif self._gating_type is GatingType.NONE:
output = sum(outputs)
else:
gate = self._compute_gate(gates)
data = [g * d for g, d in zip(gate, outputs)]
output = sum(data)
return output
46 changes: 46 additions & 0 deletions ding/torch_utils/network/tests/test_merge.py
@@ -0,0 +1,46 @@
import pytest
import torch
from ding.torch_utils.network import GatingType, SumMerge, VectorMerge


@pytest.mark.unittest
def test_SumMerge():
input_shape = (3, 5)
input = [torch.rand(input_shape).requires_grad_(True) for i in range(4)]
sum_merge = SumMerge()

output = sum_merge(input)
assert output.shape == (3, 5)
loss = output.mean()
loss.backward()
assert isinstance(input[0].grad, torch.Tensor)


@pytest.mark.unittest
def test_VectorMerge():
input_sizes = {'in1': 3, 'in2': 16, 'in3': 27}
output_size = 512
input_dict = {}
for k, v in input_sizes.items():
input_dict[k] = torch.rand((64, v)).requires_grad_(True)

vector_merge = VectorMerge(input_sizes, output_size, GatingType.NONE)
output = vector_merge(input_dict)
assert output.shape == (64, output_size)
loss = output.mean()
loss.backward()
assert isinstance(input_dict['in1'].grad, torch.Tensor)

vector_merge = VectorMerge(input_sizes, output_size, GatingType.GLOBAL)
output = vector_merge(input_dict)
assert output.shape == (64, output_size)
loss = output.mean()
loss.backward()
assert isinstance(input_dict['in1'].grad, torch.Tensor)

vector_merge = VectorMerge(input_sizes, output_size, GatingType.POINTWISE)
output = vector_merge(input_dict)
assert output.shape == (64, output_size)
loss = output.mean()
loss.backward()
assert isinstance(input_dict['in1'].grad, torch.Tensor)
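

# A hypothetical additional check (not part of this PR's test file): VectorMerge also accepts
# scalar streams declared with a size of 0, which encode() unsqueezes to shape (batch, 1).
# The test name, sizes and shapes below are illustrative assumptions.
@pytest.mark.unittest
def test_VectorMerge_scalar_stream():
    input_sizes = {'vec': 8, 'scalar': 0}
    input_dict = {
        'vec': torch.rand(16, 8),
        'scalar': torch.rand(16),
    }
    vector_merge = VectorMerge(input_sizes, 32, GatingType.NONE)
    output = vector_merge(input_dict)
    assert output.shape == (16, 32)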