Skip to content

Commit

Permalink
Add bounds check in SSD-TBE
Browse files Browse the repository at this point in the history
Summary: As title

Differential Revision: D61398178
  • Loading branch information
sryap authored and facebook-github-bot committed Aug 19, 2024
1 parent 81efd37 commit 555bc80
Showing 1 changed file with 30 additions and 0 deletions.
30 changes: 30 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from fbgemm_gpu.runtime_monitor import TBEStatsReporter, TBEStatsReporterConfig
from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType, SparseType
from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
BoundsCheckMode,
CacheAlgorithm,
EmbeddingLocation,
PoolingMode,
Expand Down Expand Up @@ -114,6 +115,7 @@ def __init__(
CowClipDefinition
] = None, # used by Rowwise Adagrad
pooling_mode: PoolingMode = PoolingMode.SUM,
bounds_check_mode: BoundsCheckMode = BoundsCheckMode.WARNING,
# Parameter Server Configs
ps_hosts: Optional[Tuple[Tuple[str, int]]] = None,
ps_max_key_per_request: Optional[int] = None,
Expand All @@ -135,6 +137,7 @@ def __init__(
super(SSDTableBatchedEmbeddingBags, self).__init__()

self.pooling_mode = pooling_mode
self.bounds_check_mode_int: int = bounds_check_mode.value
self.embedding_specs = embedding_specs
(rows, dims) = zip(*embedding_specs)
T_ = len(self.embedding_specs)
Expand Down Expand Up @@ -187,6 +190,20 @@ def __init__(
f"TBE will allocate a UVM buffer with is_host_mapped={uvm_host_mapped}"
)

# Buffers for bounds check
self.register_buffer(
"rows_per_table",
torch.tensor(
[rows[t] for t in self.feature_table_map],
device=self.current_device,
dtype=torch.int64,
),
)
self.register_buffer(
"bounds_check_warning",
torch.tensor([0], device=self.current_device, dtype=torch.int64),
)

assert cache_sets > 0
element_size = weights_precision.bit_rate() // 8
assert (
Expand Down Expand Up @@ -1232,6 +1249,19 @@ def forward(
# Force casting per_sample_weights to float
if per_sample_weights is not None:
per_sample_weights = per_sample_weights.float()

if self.bounds_check_mode_int != BoundsCheckMode.NONE.value:
torch.ops.fbgemm.bounds_check_indices(
self.rows_per_table,
indices,
offsets,
self.bounds_check_mode_int,
self.bounds_check_warning,
per_sample_weights,
B_offsets=None,
max_B=-1,
)

if len(self.timesteps_prefetched) == 0:
self.prefetch(indices, offsets)
assert len(self.ssd_prefetch_data) > 0
Expand Down

0 comments on commit 555bc80

Please sign in to comment.