Skip to content

Commit

Permalink
Add FP16 embedding weights support in dense TBE (#1343)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1343

Intended for data-parallel dense TBE evaluation: allowing the embedding weights to be allocated in FP16 reduces memory pressure and helps avoid OOM.

Differential Revision: D39712548

fbshipit-source-id: faa28cebb5ba2db60014dff2cb18cd49715db333
  • Loading branch information
jianyuh authored and facebook-github-bot committed Sep 25, 2022
1 parent 6da1b24 commit 6874421
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
5 changes: 5 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1350,13 +1350,17 @@ def __init__(
self,
embedding_specs: List[Tuple[int, int]], # tuple of (rows, dims)
feature_table_map: Optional[List[int]] = None, # [T]
weights_precision: SparseType = SparseType.FP32,
pooling_mode: PoolingMode = PoolingMode.SUM,
use_cpu: bool = False,
) -> None: # noqa C901 # tuple of (rows, dims,)
super(DenseTableBatchedEmbeddingBagsCodegen, self).__init__()

self.weights_precision = weights_precision
self.pooling_mode = pooling_mode

table_embedding_dtype = weights_precision.as_dtype()

self.use_cpu = use_cpu
# pyre-fixme[8]: Attribute has type `device`; used as `Union[int, device]`.
self.current_device: torch.device = (
Expand Down Expand Up @@ -1406,6 +1410,7 @@ def __init__(
torch.randn(
weights_offsets[-1],
device=self.current_device,
dtype=table_embedding_dtype,
)
)
for feature in range(T):
Expand Down
3 changes: 1 addition & 2 deletions fbgemm_gpu/test/split_table_batched_embeddings_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1292,9 +1292,8 @@ def test_backward_dense(
embedding_specs=[(E, D) for (E, D) in zip(Es, Ds)],
pooling_mode=pooling_mode,
use_cpu=use_cpu,
weights_precision=weights_precision,
)
if weights_precision == SparseType.FP16 and not use_cpu:
cc = cc.half()
if do_pooling:
# NOTE: test TorchScript-compatible!
cc = torch.jit.script(cc)
Expand Down

0 comments on commit 6874421

Please sign in to comment.