Skip to content

Commit

Permalink
one_embedding add doc string (#7902)
Browse files Browse the repository at this point in the history
* add doc string

* add example

* add

* fix doc

* refine

* address review

* mb to MB

* add make_table_option

* option to options

* refine

* add forward

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
2 people authored and xiacijie committed Apr 24, 2022
1 parent 58a6246 commit c3a3f0c
Show file tree
Hide file tree
Showing 3 changed files with 249 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ OneFlow API Reference
utils
env
comm
one_embedding



Expand Down
18 changes: 18 additions & 0 deletions docs/source/one_embedding.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
oneflow.one_embedding
===================================
OneFlow one_embedding operations.
----------------------------------
.. currentmodule:: oneflow.one_embedding
.. autoclass:: MultiTableEmbedding
:members: forward,
save_snapshot,
load_snapshot,

.. autofunction:: oneflow.one_embedding.MultiTableEmbedding.forward
.. autofunction:: oneflow.one_embedding.make_device_mem_store_options
.. autofunction:: oneflow.one_embedding.make_cached_ssd_store_options
.. autofunction:: oneflow.one_embedding.make_cached_host_mem_store_options
.. autofunction:: oneflow.one_embedding.make_uniform_initializer
.. autofunction:: oneflow.one_embedding.make_normal_initializer
.. autofunction:: oneflow.one_embedding.make_table_options
.. autofunction:: oneflow.one_embedding.make_table
231 changes: 230 additions & 1 deletion python/oneflow/one_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,81 @@ def _check_cache(cache):


class MultiTableEmbedding(Module):
r"""MultiTableEmbedding represent multi Embedding tables with same embedding_dim, dtype, and key_type.
Args:
name (str): The name of Embedding
embedding_dim (int): the size of each embedding vector
dtype (flow.dtype): the data type of embeddings
key_type (flow.dtype): the data type of feature ids
tables (list): list of table params, each of which can be made by flow.one_embedding.make_table_options
store_options (dict): store option of Embedding
default_initializer (dict, optional): if tables param is None, use default_initializer to initialize table. Defaults to None.
For example:
.. code-block:: python
>>> import oneflow as flow
>>> import numpy as np
>>> import oneflow.nn as nn
>>> # a simple example with 3 tables
>>> table_size_array = [39884407, 39043, 17289]
>>> vocab_size = sum(table_size_array)
>>> num_tables = len(table_size_array)
>>> embedding_size = 128
>>> scales = np.sqrt(1 / np.array(table_size_array))
>>> tables = [
>>> flow.one_embedding.make_table_options(
>>> flow.one_embedding.make_uniform_initializer(low=-scale, high=scale)
>>> )
>>> for scale in scales
>>> ]
>>> store_options = flow.one_embedding.make_cached_ssd_store_options(
>>> cache_budget_mb=8192, persistent_path="/your_path_to_ssd", capacity=vocab_size,
>>> )
>>> embedding = flow.one_embedding.MultiTableEmbedding(
>>> name="my_embedding",
>>> embedding_dim=embedding_size,
>>> dtype=flow.float,
>>> key_type=flow.int64,
>>> tables=tables,
>>> store_options=store_options,
>>> )
>>> embedding.to("cuda")
>>> mlp = flow.nn.FusedMLP(
>>> in_features=embedding_size * num_tables,
>>> hidden_features=[512, 256, 128],
>>> out_features=1,
>>> skip_final_activation=True,
>>> )
>>> mlp.to("cuda")
>>>
>>> class TrainGraph(flow.nn.Graph):
>>> def __init__(self,):
>>> super().__init__()
>>> self.embedding_lookup = embedding
>>> self.mlp = mlp
>>> self.add_optimizer(
>>> flow.optim.SGD(self.embedding_lookup.parameters(), lr=0.1, momentum=0.0)
>>> )
>>> self.add_optimizer(
>>> flow.optim.SGD(self.mlp.parameters(), lr=0.1, momentum=0.0)
>>> )
>>> def build(self, ids):
>>> embedding = self.embedding_lookup(ids)
>>> loss = self.mlp(flow.reshape(embedding, (-1, num_tables * embedding_size)))
>>> loss = loss.sum()
>>> loss.backward()
>>> return loss
>>> ids = np.random.randint(0, 1000, (100, num_tables), dtype=np.int64)
>>> ids_tensor = flow.tensor(ids, requires_grad=False).to("cuda")
>>> graph = TrainGraph()
>>> loss = graph(ids_tensor)
>>> print(loss)
"""

def __init__(
self,
name,
Expand Down Expand Up @@ -194,12 +269,50 @@ def _load_from_state_dict(
)

def save_snapshot(self, snapshot_name):
    """Save a snapshot of this embedding under the given name.

    The call is forwarded to ``self.handler.SaveSnapshot``.

    Args:
        snapshot_name (str): name of the snapshot; the snapshot is saved in the "snapshots" dir under your configured persistent path

    For example:

    .. code-block:: python

        >>> import oneflow as flow
        >>> # `embedding` was created by flow.one_embedding.MultiTableEmbedding
        >>> embedding.save_snapshot("my_snapshot1")
        >>> # a snapshot named "my_snapshot1" has been saved in the "snapshots" dir
        >>> # under your configured persistent path; it can be reloaded with load_snapshot

    """
    handler = self.handler
    handler.SaveSnapshot(snapshot_name)

def load_snapshot(self, snapshot_name):
    """Load a previously saved snapshot by name.

    The call is forwarded to ``self.handler.LoadSnapshot``.

    Args:
        snapshot_name (str): name of the snapshot; the snapshot is loaded from your configured persistent path

    For example:

    .. code-block:: python

        >>> import oneflow as flow
        >>> # `embedding` was created by flow.one_embedding.MultiTableEmbedding
        >>> embedding.load_snapshot("my_snapshot1")
        >>> # the snapshot named "my_snapshot1" is loaded from your configured persistent path

    """
    handler = self.handler
    handler.LoadSnapshot(snapshot_name)

def forward(self, ids, table_ids=None):
"""forward of MultiTableEmbedding
Args:
ids (flow.tensor): the feature ids
table_ids (flow.tensor, optional): the table_id of each id, must be the same shape as ids. There is no need to pass table_ids if only one table is configured, or if ids has shape (batch_size, num_tables) and each column's ids belong to the table with that column index; otherwise, you should pass table_ids.
Returns:
flow.tensor: the result of embedding lookup
"""
assert self.key_type == ids.dtype, "ids data_type must equals key_type"
return flow._C.one_embedding_lookup(
self.shadow,
Expand All @@ -216,6 +329,20 @@ def forward(self, ids, table_ids=None):
def make_device_mem_store_options(
persistent_path, capacity, size_factor=1, physical_block_size=512
):
"""make GPU only store_options param of MultiTableEmbedding
Args:
persistent_path (str, list): persistent storage path of Embedding. If a str is passed, the current rank's Embedding will be saved in the path/rank_id-num_ranks path. If a list is passed, the list length must equal num_ranks, and each element of the list represents the path of the corresponding rank_id's Embedding.
capacity (int): total capacity of Embedding
size_factor (int, optional): store size factor of embedding_dim. For SGD with momentum = 0 it should be 1; for SGD with momentum > 0 it should be 2; for Adam it should be 3. Defaults to 1.
physical_block_size (int, optional): physical_block_size should be sector size. Defaults to 512.
Returns:
dict: GPU only store_options param of MultiTableEmbedding
See also :func:`oneflow.one_embedding.make_cached_ssd_store_options`
"""

assert isinstance(persistent_path, (str, list, tuple))
assert capacity > 0
options = {
Expand Down Expand Up @@ -245,6 +372,29 @@ def make_cached_ssd_store_options(
size_factor=1,
physical_block_size=512,
):
"""make SSD use GPU as cache store_options param of MultiTableEmbedding
Args:
cache_budget_mb (int): the cache budget per GPU, in MB.
persistent_path (str, list): persistent storage path of Embedding, must use a fast SSD because of frequent random disk access during training. If a str is passed, the current rank's Embedding will be saved in the path/rank_id-num_ranks path. If a list is passed, the list length must equal num_ranks, and each element of the list represents the path of the corresponding rank_id's Embedding.
capacity (int): total capacity of Embedding
size_factor (int, optional): store size factor of embedding_dim. For SGD with momentum = 0 it should be 1; for SGD with momentum > 0 it should be 2; for Adam it should be 3. Defaults to 1.
physical_block_size (int, optional): physical_block_size should be sector size. Defaults to 512.
Returns:
dict: SSD use GPU as cache store_options param of MultiTableEmbedding
For example:
.. code-block:: python
>>> import oneflow as flow
>>> store_options = flow.one_embedding.make_cached_ssd_store_options(
>>> cache_budget_mb=8192, persistent_path="/your_path_to_ssd", capacity=vocab_size,
>>> )
>>> # pass the store_options to the "store_options" param of flow.one_embedding.MultiTableEmbedding
>>> # ...
"""
assert isinstance(persistent_path, (str, list, tuple))
assert cache_budget_mb > 0
if capacity is not None:
Expand Down Expand Up @@ -274,6 +424,20 @@ def make_cached_ssd_store_options(
def make_cached_host_mem_store_options(
cache_budget_mb, persistent_path, capacity, size_factor=1, physical_block_size=512,
):
"""make host use GPU as cache store_options param of MultiTableEmbedding
Args:
cache_budget_mb (int): the cache budget per GPU, in MB.
persistent_path (str, list): persistent storage path of Embedding. If a str is passed, the current rank's Embedding will be saved in the path/rank_id-num_ranks path. If a list is passed, the list length must equal num_ranks, and each element of the list represents the path of the corresponding rank_id's Embedding.
capacity (int): total capacity of Embedding
size_factor (int, optional): store size factor of embedding_dim. For SGD with momentum = 0 it should be 1; for SGD with momentum > 0 it should be 2; for Adam it should be 3. Defaults to 1.
physical_block_size (int, optional): physical_block_size should be sector size. Defaults to 512.
Returns:
dict: host use GPU as cache store_options param of MultiTableEmbedding
See also :func:`oneflow.one_embedding.make_cached_ssd_store_options`
"""
assert isinstance(persistent_path, (str, list, tuple))
assert cache_budget_mb > 0
assert capacity > 0
Expand Down Expand Up @@ -303,12 +467,77 @@ def make_cached_host_mem_store_options(


def make_uniform_initializer(low, high):
    """Build the uniform initializer param accepted by make_table_options.

    Args:
        low (float): A python scalar. Lower bound of the range of random values to generate.
        high (float): A python scalar. Upper bound of the range of random values to generate.

    Returns:
        dict: initializer param of make_table_options, with keys "type" (always "uniform"), "low" and "high"

    For example:

    .. code-block:: python

        >>> import oneflow as flow
        >>> initializer = flow.one_embedding.make_uniform_initializer(low=-0.05, high=0.05)
        >>> # pass the initializer to flow.one_embedding.make_table_options
        >>> # ...

    """
    # Keyword form keeps the same key order as the literal: type, low, high.
    return dict(type="uniform", low=low, high=high)


def make_normal_initializer(mean, std):
    """Build the normal initializer param accepted by make_table_options.

    Args:
        mean (float): A python scalar. Mean of the random values to generate.
        std (float): A python scalar. Standard deviation of the random values to generate.

    Returns:
        dict: initializer param of make_table_options, with keys "type" (always "normal"), "mean" and "std"

    For example:

    .. code-block:: python

        >>> import oneflow as flow
        >>> initializer = flow.one_embedding.make_normal_initializer(mean=0, std=0.01)
        >>> # pass the initializer to flow.one_embedding.make_table_options
        >>> # ...

    """
    # Keyword form keeps the same key order as the literal: type, mean, std.
    return dict(type="normal", mean=mean, std=std)


def make_table(initializer):
def make_table_options(initializer):
    """Build one table param for the ``tables`` argument of MultiTableEmbedding.

    Args:
        initializer (dict): initializer param, made by make_uniform_initializer or make_normal_initializer

    Returns:
        dict: table param of MultiTableEmbedding tables, holding the given initializer under the "initializer" key

    For example:

    .. code-block:: python

        >>> import oneflow as flow
        >>> initializer = flow.one_embedding.make_uniform_initializer(low=-0.05, high=0.05)
        >>> table1 = flow.one_embedding.make_table_options(initializer)
        >>> table2 = flow.one_embedding.make_table_options(initializer)
        >>> tables = [table1, table2]
        >>> # pass the tables to the "tables" param of flow.one_embedding.MultiTableEmbedding
        >>> # ...

    """
    # The initializer object is stored as-is (not copied).
    table_options = dict(initializer=initializer)
    return table_options


def make_table(initializer):
    """Alias of `oneflow.one_embedding.make_table_options`.

    Args:
        initializer (dict): initializer param, forwarded unchanged to make_table_options

    Returns:
        dict: whatever make_table_options returns for the given initializer

    See also :func:`oneflow.one_embedding.make_table_options`
    """
    table_options = make_table_options(initializer)
    return table_options

0 comments on commit c3a3f0c

Please sign in to comment.