Cleanup new incremental state API (#1005)
Summary:
* Now that we have `FairseqIncrementalState`, we can move `get_incremental_state` and `set_incremental_state` into that class as methods, instead of keeping them as helper functions in `utils.py` (see the usage sketch after this list). I think this will eventually help with type checking too.
* The incremental ID logic was overly complicated; we can just use `uuid` to generate a unique ID for every instance.
* Add missing `with_incremental_state` to light/dynamic conv modules.
* Add an additional unit test: `test_incremental_state_multihead_attention`
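
As a quick illustration of the resulting API, here is a minimal sketch of a module that uses it; the `MyCachedModule` class and its cache keys are hypothetical, for illustration only:

```python
from typing import Dict, Optional

import torch
from torch import Tensor, nn

from fairseq.incremental_decoding_utils import with_incremental_state


@with_incremental_state
class MyCachedModule(nn.Module):  # hypothetical module, not part of this diff
    def forward(
        self,
        x: Tensor,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
    ) -> Tensor:
        if incremental_state is not None:
            # Look up this instance's cached state (None on the first call).
            cached = self.get_incremental_state(incremental_state, "cache")
            prev = cached.get("prev_x") if cached is not None else None
            if prev is not None:
                x = torch.cat([prev, x], dim=0)
            # Keep the returned dict, matching the updated call sites in this commit.
            incremental_state = self.set_incremental_state(
                incremental_state, "cache", {"prev_x": x}
            )
        return x
```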

Pull Request resolved: fairinternal/fairseq-py#1005

Test Plan:
* unit tests

Also confirmed this matches master:
```
$ python generate.py ~/data/data-bin/wmt16_en_de_bpe32k --path /checkpoint/myleott/s3/models/wmt16.en-de.joined-dict.transformer/model.pt --beam 4 --lenpen 0.6 --remove-bpe --quiet
(...)
2020-01-22 09:53:38 | INFO | fairseq_cli.generate | Generate test with beam=4: BLEU4 = 29.28, 60.8/35.1/22.8/15.3 (BP=0.997, ratio=0.997, syslen=62859, reflen=63078)
```

Reviewed By: cndn

Differential Revision: D19517908

Pulled By: myleott

fbshipit-source-id: a406490e342d0d30a9231bf823d3350999bda4c0
myleott authored and facebook-github-bot committed Jan 27, 2020
1 parent 9f4256e commit 88185fc
Showing 7 changed files with 80 additions and 52 deletions.
49 changes: 35 additions & 14 deletions fairseq/incremental_decoding_utils.py
@@ -3,27 +3,48 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from fairseq import utils
from typing import Dict, Optional
import uuid

from torch import Tensor


class FairseqIncrementalState(object):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
init_incremental_state(self)
self.init_incremental_state()

def init_incremental_state(self):
self._incremental_state_id = str(uuid.uuid4())

def _get_full_incremental_state_key(self, key: str) -> str:
return "{}.{}".format(self._incremental_state_id, key)

def get_incremental_state(
self,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
key: str,
) -> Optional[Dict[str, Optional[Tensor]]]:
"""Helper for getting incremental state for an nn.Module."""
full_key = self._get_full_incremental_state_key(key)
if incremental_state is None or full_key not in incremental_state:
return None
return incremental_state[full_key]

def set_incremental_state(
self,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
key: str,
value: Dict[str, Optional[Tensor]],
) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]:
"""Helper for setting incremental state for an nn.Module."""
if incremental_state is not None:
full_key = self._get_full_incremental_state_key(key)
incremental_state[full_key] = value
return incremental_state


def with_incremental_state(cls):
cls.__bases__ = (FairseqIncrementalState,) + tuple(b for b in cls.__bases__ if b != FairseqIncrementalState)
return cls


# In most cases we should register incremental states using @with_incremental_state decorator
# instead of calling into this explicitly in initializer.
def init_incremental_state(obj):
obj.module_name = obj.__class__.__name__
utils.INCREMENTAL_STATE_INSTANCE_ID[obj.module_name] = (
utils.INCREMENTAL_STATE_INSTANCE_ID.get(obj.module_name, 0) + 1
)
obj._fairseq_instance_id = utils.INCREMENTAL_STATE_INSTANCE_ID[
obj.module_name
]
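
For context, a small sketch of what the uuid-based scheme gives you: each decorated instance gets its own `_incremental_state_id`, so two instances writing the same key into a shared state dict never collide. The `Foo` class below is purely illustrative:

```python
from torch import nn

from fairseq.incremental_decoding_utils import (
    FairseqIncrementalState,
    with_incremental_state,
)


@with_incremental_state
class Foo(nn.Module):  # illustrative only
    pass


a, b = Foo(), Foo()
assert isinstance(a, FairseqIncrementalState)  # mixed in by the decorator
assert a._incremental_state_id != b._incremental_state_id  # one uuid per instance

state = {}
state = a.set_incremental_state(state, "cache", {"x": None})
assert b.get_incremental_state(state, "cache") is None  # b's keys are namespaced separately
```
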
16 changes: 9 additions & 7 deletions fairseq/models/fairseq_incremental_decoder.py
@@ -67,14 +67,16 @@ def reorder_incremental_state(self, incremental_state, new_order):
order changes between time steps based on the selection of beams.
"""
seen = set()

def apply_reorder_incremental_state(module):
if module != self and hasattr(module, 'reorder_incremental_state') \
and module not in seen:
for module in self.modules():
if (
module != self
and hasattr(module, 'reorder_incremental_state')
and module not in seen
):
seen.add(module)
module.reorder_incremental_state(incremental_state, new_order)

self.apply(apply_reorder_incremental_state)
result = module.reorder_incremental_state(incremental_state, new_order)
if result is not None:
incremental_state = result

def set_beam_size(self, beam_size):
"""Sets the beam size in the decoder and all children."""
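
A minimal sketch of the contract the updated loop honors: a child's `reorder_incremental_state` may return a replacement state dict, as `MultiheadAttention.reorder_incremental_state` now does, and the parent keeps whatever is returned instead of relying purely on in-place mutation. The `CachingChild` module below is hypothetical:

```python
import torch
from torch import nn


class CachingChild(nn.Module):  # hypothetical child module, for illustration
    def reorder_incremental_state(self, incremental_state, new_order):
        # Build and return a reordered copy instead of mutating in place.
        return {
            key: {
                k: (v.index_select(0, new_order) if v is not None else None)
                for k, v in inner.items()
            }
            for key, inner in incremental_state.items()
        }


child = CachingChild()
state = {"child.cache": {"prev": torch.arange(6).view(3, 2)}}
result = child.reorder_incremental_state(state, torch.tensor([2, 0, 1]))
if result is not None:  # mirrors the updated loop in reorder_incremental_state
    state = result
```
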
2 changes: 2 additions & 0 deletions fairseq/modules/dynamicconv_layer/dynamicconv_layer.py
@@ -11,6 +11,7 @@
import dynamicconv_cuda
from fairseq import utils
from fairseq.modules.unfold import unfold1d
from fairseq.incremental_decoding_utils import with_incremental_state


class dynamicconvFunction(Function):
@@ -33,6 +34,7 @@ def backward(ctx, grad_output):
return grad_input, grad_weights, None


@with_incremental_state
class DynamicconvLayer(nn.Module):
def __init__(
self,
2 changes: 2 additions & 0 deletions fairseq/modules/lightconv_layer/lightconv_layer.py
@@ -10,6 +10,7 @@

import lightconv_cuda
from fairseq import utils
from fairseq.incremental_decoding_utils import with_incremental_state


class lightconvFunction(Function):
@@ -32,6 +33,7 @@ def backward(ctx, grad_output):
return grad_input, grad_weights, None


@with_incremental_state
class LightconvLayer(nn.Module):
def __init__(
self,
23 changes: 10 additions & 13 deletions fairseq/modules/multihead_attention.py
@@ -274,7 +274,7 @@ def forward(
saved_state["prev_key_padding_mask"] = key_padding_mask
# In this branch incremental_state is never None
assert incremental_state is not None
self._set_input_buffer(incremental_state, saved_state)
incremental_state = self._set_input_buffer(incremental_state, saved_state)
assert k is not None
src_len = k.size(1)

@@ -405,28 +405,25 @@ def reorder_incremental_state(
for k in input_buffer.keys():
if input_buffer[k] is not None:
input_buffer[k] = input_buffer[k].index_select(0, new_order)
self._set_input_buffer(incremental_state, input_buffer)
incremental_state = self._set_input_buffer(incremental_state, input_buffer)
return incremental_state

def _get_input_buffer(
self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
) -> Dict[str, Optional[Tensor]]:
empty_dict_annotated: Dict[str, Optional[Tensor]] = {}
if incremental_state is None:
return empty_dict_annotated
full_key = utils._get_full_incremental_state_key(self, "attn_state")
if full_key not in incremental_state:
return empty_dict_annotated
return incremental_state[full_key]
result = self.get_incremental_state(incremental_state, "attn_state")
if result is not None:
return result
else:
empty_result: Dict[str, Optional[Tensor]] = {}
return empty_result

def _set_input_buffer(
self,
incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
buffer: Dict[str, Optional[Tensor]],
):
full_key = utils._get_full_incremental_state_key(
self, "attn_state"
)
incremental_state[full_key] = buffer
return self.set_incremental_state(incremental_state, "attn_state", buffer)

def apply_sparse_mask(attn_weights, tgt_len: int, src_len: int, bsz: int):
return attn_weights
24 changes: 6 additions & 18 deletions fairseq/utils.py
@@ -61,39 +61,27 @@ def _move_to_cuda(tensor):
return apply_to_sample(_move_to_cuda, sample)


INCREMENTAL_STATE_INSTANCE_ID = {}


def _get_full_incremental_state_key(
module_instance: MultiheadAttention, key: str
) -> str:
return "{}.{}.{}".format(
module_instance.module_name, module_instance._fairseq_instance_id, key
)


def get_incremental_state(
module: MultiheadAttention,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
key: str,
) -> Optional[Dict[str, Optional[Tensor]]]:
"""Helper for getting incremental state for an nn.Module."""
full_key = _get_full_incremental_state_key(module, key)
if incremental_state is None or full_key not in incremental_state:
return None
return incremental_state[full_key]
return module.get_incremental_state(incremental_state, key)


def set_incremental_state(
module: MultiheadAttention,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
key: str,
value: Dict[str, Optional[Tensor]],
):
) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]:
"""Helper for setting incremental state for an nn.Module."""
if incremental_state is not None:
full_key = _get_full_incremental_state_key(module, key)
incremental_state[full_key] = value
result = module.set_incremental_state(incremental_state, key, value)
if result is not None:
incremental_state = result
return incremental_state


def load_align_dict(replace_unk):
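
The module-level helpers stay around for backward compatibility but are now thin wrappers over the instance methods. A quick sketch of an old-style call site that keeps working unchanged; the cache contents below are made up:

```python
import torch
from fairseq import utils
from fairseq.modules.multihead_attention import MultiheadAttention

attn = MultiheadAttention(embed_dim=8, num_heads=2)
state = {}
# Both helpers just delegate to attn.set_incremental_state / attn.get_incremental_state.
state = utils.set_incremental_state(
    attn, state, "attn_state", {"prev_key": torch.zeros(1, 2, 1, 4)}
)
assert utils.get_incremental_state(attn, state, "attn_state") is not None
```
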
16 changes: 16 additions & 0 deletions tests/test_export.py
@@ -7,10 +7,26 @@


class TestExportModels(unittest.TestCase):

def test_export_multihead_attention(self):
module = multihead_attention.MultiheadAttention(embed_dim=8, num_heads=2)
torch.jit.script(module)

def test_incremental_state_multihead_attention(self):
module1 = multihead_attention.MultiheadAttention(embed_dim=8, num_heads=2)
module1 = torch.jit.script(module1)
module2 = multihead_attention.MultiheadAttention(embed_dim=8, num_heads=2)
module2 = torch.jit.script(module2)

state = {}
state = module1.set_incremental_state(state, 'key', {'a': torch.tensor([1])})
state = module2.set_incremental_state(state, 'key', {'a': torch.tensor([2])})
v1 = module1.get_incremental_state(state, 'key')['a']
v2 = module2.get_incremental_state(state, 'key')['a']

self.assertEqual(v1, 1)
self.assertEqual(v2, 2)

def test_positional_embedding(self):
module = sinusoidal_positional_embedding.SinusoidalPositionalEmbedding(
embedding_dim=8, padding_idx=1
