Add embedding weights FP16 support in dense TBE
Summary: Support FP16 embedding weights for data-parallel dense TBE evaluation; this reduces memory pressure and avoids OOMs.

Differential Revision: D39712548

fbshipit-source-id: c3193f0ef3035136ca938537b699c8a53889e489
jianyuh authored and facebook-github-bot committed Sep 21, 2022
1 parent c803cb9 commit 9b0f0cd
Showing 2 changed files with 8 additions and 2 deletions.
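
For scale (illustrative numbers, not from the commit): storing embedding weights in FP16 halves the table footprint relative to FP32, since each element takes 2 bytes instead of 4. A table with E rows and embedding dimension D occupies 4·E·D bytes in FP32 versus 2·E·D bytes in FP16, so a 10M-row, 128-dim table shrinks from roughly 5.1 GB to 2.6 GB.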
7 changes: 7 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py
@@ -1350,13 +1350,19 @@ def __init__(
         self,
         embedding_specs: List[Tuple[int, int]],  # tuple of (rows, dims)
         feature_table_map: Optional[List[int]] = None,  # [T]
+        weights_precision: SparseType = SparseType.FP32,
         output_dtype: SparseType = SparseType.FP32,
         pooling_mode: PoolingMode = PoolingMode.SUM,
         use_cpu: bool = False,
     ) -> None:  # noqa C901  # tuple of (rows, dims,)
         super(DenseTableBatchedEmbeddingBagsCodegen, self).__init__()
+
+        self.weights_precision = weights_precision
         self.output_dtype: int = output_dtype.as_int()
         self.pooling_mode = pooling_mode
+
+        table_embedding_dtype = weights_precision.as_dtype()
+
         self.use_cpu = use_cpu
         # pyre-fixme[8]: Attribute has type `device`; used as `Union[int, device]`.
         self.current_device: torch.device = (
@@ -1406,6 +1412,7 @@ def __init__(
             torch.randn(
                 weights_offsets[-1],
                 device=self.current_device,
+                dtype=table_embedding_dtype,
             )
         )
         for feature in range(T):
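
The new weights_precision argument drives the dtype of the randomly initialized weights buffer through SparseType.as_dtype(). A minimal sketch of the expected behavior (the import path is taken from fbgemm_gpu; the tensor size is hypothetical):

    import torch
    from fbgemm_gpu.split_embedding_configs import SparseType

    # SparseType.FP16.as_dtype() maps to torch.float16, so the dense TBE
    # weights buffer above is allocated at half precision from the start.
    dtype = SparseType.FP16.as_dtype()
    weights = torch.randn(1024, device="cpu", dtype=dtype)  # hypothetical size
    assert weights.dtype == torch.float16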
3 changes: 1 addition & 2 deletions fbgemm_gpu/test/split_table_batched_embeddings_test.py
@@ -1292,9 +1292,8 @@ def test_backward_dense(
             embedding_specs=[(E, D) for (E, D) in zip(Es, Ds)],
             pooling_mode=pooling_mode,
             use_cpu=use_cpu,
+            weights_precision=weights_precision,
         )
-        if weights_precision == SparseType.FP16 and not use_cpu:
-            cc = cc.half()
         if do_pooling:
             # NOTE: test TorchScript-compatible!
             cc = torch.jit.script(cc)
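
Put together, the changed call site suggests usage like the following; a minimal sketch with hypothetical table shapes (the class, enum, and parameter names come from the diff; the import paths reflect fbgemm_gpu's layout at the time of this commit):

    import torch
    from fbgemm_gpu.split_embedding_configs import SparseType
    from fbgemm_gpu.split_table_batched_embeddings_ops import (
        DenseTableBatchedEmbeddingBagsCodegen,
        PoolingMode,
    )

    # Weights are allocated directly in FP16 via weights_precision, so the
    # caller no longer needs to build FP32 tables and convert with cc.half().
    cc = DenseTableBatchedEmbeddingBagsCodegen(
        embedding_specs=[(1000, 64), (2000, 32)],  # hypothetical (rows, dims)
        pooling_mode=PoolingMode.SUM,
        use_cpu=True,
        weights_precision=SparseType.FP16,
    )
    # Assumes the module exposes its flattened parameter as `weights`.
    assert cc.weights.dtype == torch.float16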
