Remove vocab from cuda (facebookresearch#955)
Summary:
Pull Request resolved: facebookresearch#955

We have users who can't train models with extremely large embedding tables because we try to allocate space for them on the GPU.

With this diff, we add a training flag that users can set explicitly to keep the embedding layer on the CPU even when the model is trained on GPUs. It is off by default because users need to be aware of the cost of moving tensors on and off the GPU on every forward pass.

Note that this only applies during training.

Also note that this does not work in a multi-GPU environment because of the way the weights are synced via NCCL.
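
To make the trade-off concrete, here is a minimal, self-contained sketch of the pattern this diff implements (plain PyTorch, not PyText code; all names are illustrative): the embedding table stays on the CPU while the rest of the model runs on the GPU, paying a small device round-trip on every batch.

import torch
import torch.nn as nn


class CPUEmbeddingModel(nn.Module):
    """Toy model whose (potentially huge) embedding table is pinned to the CPU."""

    def __init__(self, vocab_size: int, embed_dim: int, hidden_dim: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)  # stays on the CPU
        self.mlp = nn.Linear(embed_dim, hidden_dim)  # moved to the GPU below

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        input_device = token_ids.device
        # Move only the small index tensor to the CPU for the lookup ...
        embedded = self.embedding(token_ids.to(self.embedding.weight.device))
        # ... then ship the small result back; the big table never moves.
        return self.mlp(embedded.to(input_device))


model = CPUEmbeddingModel(vocab_size=100_000, embed_dim=300, hidden_dim=128)
if torch.cuda.is_available():
    model.mlp.cuda()  # only the MLP is moved; the embedding stays put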

Differential Revision: D17114398

fbshipit-source-id: e28b2981fbcbb248a6a704fd3c6e325fd45490e9
snisarg authored and facebook-github-bot committed Sep 24, 2019
1 parent c7dd752 commit 189fdf4
Showing 3 changed files with 25 additions and 2 deletions.
1 change: 1 addition & 0 deletions pytext/config/field_config.py
@@ -40,6 +40,7 @@ class WordFeatConfig(ModuleConfig):
     min_freq: int = 1
     mlp_layer_dims: Optional[List[int]] = []
     padding_idx: Optional[int] = None
+    cpu_only: bool = False


 class DictFeatConfig(ModuleConfig):
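
A hypothetical example of how a user might enable the new flag when building the word feature config in Python (cpu_only comes from this diff; embed_dim is an assumed pre-existing field of WordFeatConfig):

from pytext.config.field_config import WordFeatConfig

word_feat = WordFeatConfig(
    embed_dim=300,  # assumed existing field, not shown in this hunk
    cpu_only=True,  # new: keep the embedding table off the GPU during training
)
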
13 changes: 11 additions & 2 deletions pytext/models/embeddings/word_embedding.py
@@ -1,14 +1,14 @@
 #!/usr/bin/env python3
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

 import collections
 from typing import List, Optional

 import torch
 from pytext.config.field_config import WordFeatConfig
 from pytext.data.tensorizers import Tensorizer
 from pytext.fields import FieldMeta
 from pytext.utils.embeddings import PretrainedEmbedding
+from pytext.utils.torch import CPUOnlyParameter
 from tensorboardX import SummaryWriter
 from torch import nn

@@ -96,6 +96,7 @@ def from_config(
             mlp_layer_dims=config.mlp_layer_dims,
             padding_idx=config.padding_idx,
             vocab=vocab,
+            cpu_only=config.cpu_only,
         )

     def __init__(
@@ -108,6 +109,7 @@ def __init__(
         mlp_layer_dims: List[int] = (),
         padding_idx: Optional[int] = None,
         vocab: Optional[List[str]] = None,
+        cpu_only: bool = False,
     ) -> None:
         output_embedding_dim = mlp_layer_dims[-1] if mlp_layer_dims else embedding_dim
         EmbeddingBase.__init__(self, embedding_dim=output_embedding_dim)
@@ -119,6 +121,8 @@ def __init__(
             _weight=embeddings_weight,
             padding_idx=padding_idx,
         )
+        if cpu_only:
+            self.word_embedding.weight = CPUOnlyParameter(self.word_embedding.weight)
         if embeddings_weight is None and init_range:
             self.word_embedding.weight.data.uniform_(init_range[0], init_range[1])
         # Initialize unk embedding with zeros
@@ -142,7 +146,12 @@ def __getattr__(self, name):
         return super().__getattr__(name)

     def forward(self, input):
-        return self.mlp(self.word_embedding(input))
+        input_device = input.device
+        embedding_device = self.word_embedding.weight.device
+        if input_device != embedding_device:
+            input = input.to(embedding_device)
+        # We only want to do the embedding lookup on CPU
+        return self.mlp(self.word_embedding(input).to(input_device))

     def freeze(self):
         for param in self.word_embedding.parameters():
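
To isolate what the new forward() does, here is a hedged, self-contained rendition of the same device round-trip using plain torch ops (illustrative names, not PyText's API). Note that only the small index and output tensors cross the device boundary, never the full table:

import torch
import torch.nn.functional as F

weight = torch.randn(100_000, 64)  # big embedding table, pinned to the CPU
indices = torch.randint(0, 100_000, (32,))  # one batch of token ids
if torch.cuda.is_available():
    indices = indices.cuda()  # inputs arrive on the GPU during training

# Lookup on the CPU, then move only the (32, 64) result back.
lookup = F.embedding(indices.to(weight.device), weight)
output = lookup.to(indices.device)
assert output.device == indices.device
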
13 changes: 13 additions & 0 deletions pytext/utils/torch.py
@@ -5,6 +5,7 @@
 from typing import Dict, List, Optional, Tuple

 import torch
+from pytext.utils import cuda


 # ===== the following section should be replaced once JIT provide native support
@@ -500,3 +501,15 @@ def package_for_inference(self):
         self.do_normalization = torch.jit.Attribute(self.do_normalization, bool)
         self.feature_avgs = torch.jit.Attribute(self.feature_avgs, List[float])
         self.feature_stddevs = torch.jit.Attribute(self.feature_stddevs, List[float])
+
+
+class CPUOnlyParameter(torch.nn.Parameter):
+    def __init__(self, *args, **kwargs):
+        assert (
+            cuda.DISTRIBUTED_WORLD_SIZE <= 1
+        ), "Multiple GPUs not supported for cpu_only embeddings"
+        super().__init__()
+
+    def cuda(self, device=None):
+        # We do nothing because this Parameter should only be on the CPU
+        return self

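A quick hypothetical check of the override, assuming a single-process run where cuda.DISTRIBUTED_WORLD_SIZE is 1: .cuda() on a CPUOnlyParameter returns the parameter unchanged, which is what keeps the table on the CPU when the rest of the model is moved to the GPU.

import torch
from pytext.utils.torch import CPUOnlyParameter

weight = CPUOnlyParameter(torch.randn(1000, 64))
assert weight.cuda() is weight  # the override is a deliberate no-op
assert weight.device.type == "cpu"  # the table never leaves the CPU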