Merge pooled emb docstring #3172

Status: Closed · wants to merge 2 commits
@@ -1,6 +1,8 @@
Jagged Tensor Operators
=======================

.. automodule:: fbgemm_gpu

.. autofunction:: torch.ops.fbgemm.jagged_2d_to_dense

.. autofunction:: torch.ops.fbgemm.jagged_1d_to_dense
@@ -0,0 +1,6 @@
Pooled Embedding Operators
==========================

.. automodule:: fbgemm_gpu

.. autofunction:: torch.ops.fbgemm.merge_pooled_embeddings
@@ -1,5 +1,10 @@
Table Batched Embedding (TBE) Operators
=======================================
Table Batched Embedding (TBE) Training Module
=============================================

.. autoclass:: fbgemm_gpu.split_table_batched_embeddings_ops_training.SplitTableBatchedEmbeddingBagsCodegen
:members:
:members: forward,
split_embedding_weights,
split_optimizer_states,
set_learning_rate,
update_hyper_parameters,
set_optimizer_step
1 change: 1 addition & 0 deletions fbgemm_gpu/docs/src/index.rst
@@ -91,3 +91,4 @@ Table of Contents

fbgemm_gpu-python-api/table_batched_embedding_ops.rst
fbgemm_gpu-python-api/jagged_tensor_ops.rst
fbgemm_gpu-python-api/pooled_embedding_ops.rst
2 changes: 1 addition & 1 deletion fbgemm_gpu/fbgemm_gpu/docs/__init__.py
@@ -7,6 +7,6 @@

# Trigger the manual addition of docstrings to pybind11-generated operators
try:
from . import jagged_tensor_ops, table_batched_embedding_ops # noqa: F401
from . import jagged_tensor_ops, merge_pooled_embedding_ops # noqa: F401
except Exception:
pass
36 changes: 36 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/docs/merge_pooled_embedding_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch

from .common import add_docs

add_docs(
torch.ops.fbgemm.merge_pooled_embeddings,
"""
merge_pooled_embeddings(pooled_embeddings, uncat_dim_size, target_device, cat_dim=1) -> Tensor

Concatenate 2D embedding outputs from different devices (on the same
host) onto the target device.

Args:
pooled_embeddings (List[Tensor]): A list of embedding outputs from
different devices on the same host. Each output has 2
dimensions.

uncat_dim_size (int): The size of the dimension that is not
concatenated, i.e., if `cat_dim=0`, `uncat_dim_size` is the size
of dim 1 and vice versa.

target_device (torch.device): The target device that aggregates all
the embedding outputs.

cat_dim (int = 1): The dimension along which the tensors are concatenated

Returns:
The concatenated embedding output (2D) on the target device
""",
)
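To make the `cat_dim`/`uncat_dim_size` relationship concrete, here is a pure-Python sketch of the concatenation semantics only (hypothetical helper name; nested lists stand in for 2D tensors, and the real op's cross-device movement is not modeled):

```python
# Hypothetical sketch of merge_pooled_embeddings' concatenation logic.
# Each "pooled embedding" is a 2D nested list; uncat_dim_size is the
# size of the dimension NOT being concatenated.

def merge_pooled_sketch(pooled_embeddings, uncat_dim_size, cat_dim=1):
    if cat_dim == 1:
        # Every output keeps uncat_dim_size rows; rows are joined
        # side by side across outputs.
        assert all(len(t) == uncat_dim_size for t in pooled_embeddings)
        return [
            [v for t in pooled_embeddings for v in t[row]]
            for row in range(uncat_dim_size)
        ]
    # cat_dim == 0: every output keeps uncat_dim_size columns; outputs
    # are stacked vertically.
    assert all(
        len(row) == uncat_dim_size for t in pooled_embeddings for row in t
    )
    return [row for t in pooled_embeddings for row in t]


# Two per-device outputs, each 2 x 2, merged along dim 1 -> one 2 x 4.
a = [[1, 2], [3, 4]]
b = [[5, 6], [7, 8]]
print(merge_pooled_sketch([a, b], uncat_dim_size=2))
# [[1, 2, 5, 6], [3, 4, 7, 8]]
```

The same inputs with `cat_dim=0` and `uncat_dim_size=2` instead stack the outputs into a 4 x 2 result.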