Skip to content

Commit

Permalink
Optimize the cache fetch for forward split, pt. 2B (pytorch#2282)
Browse files Browse the repository at this point in the history
Summary:

This follows up the work in D51865590 and D52679387 by plumbing the `uvm_cache_stats` argument up to the Python API level.  `local_uvm_cache_stats` is now zeroed out before the prefetch step instead of after it, so that its data can be passed into the forward step.

This is a re-attempt at landing D51995949, with additions copied from D52670550.

Differential Revision: D53033916
  • Loading branch information
q10 authored and facebook-github-bot committed Feb 5, 2024
1 parent 7889f64 commit 7bbb442
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 1 deletion.
1 change: 1 addition & 0 deletions fbgemm_gpu/codegen/lookup_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class CommonArgs(NamedTuple):
indice_weights: Optional[torch.Tensor]
feature_requires_grad: Optional[torch.Tensor]
lxu_cache_locations: torch.Tensor
uvm_cache_stats: Optional[torch.Tensor]
output_dtype: int
vbe_metadata: VBEMetadata
is_experimental: bool
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ def invoke(
indice_weights=common_args.indice_weights,
feature_requires_grad=common_args.feature_requires_grad,
lxu_cache_locations=common_args.lxu_cache_locations,
uvm_cache_stats=common_args.uvm_cache_stats,
# VBE metadata
B_offsets=vbe_metadata.B_offsets,
vbe_output_offsets_feature_rank=vbe_metadata.output_offsets_feature_rank,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1011,6 +1011,11 @@ def forward( # noqa: C901
indice_weights=per_sample_weights,
feature_requires_grad=feature_requires_grad,
lxu_cache_locations=self.lxu_cache_locations,
# Pass the local_uvm_cache_stats bc only that information is
# relevant for the current iteration
uvm_cache_stats=self.local_uvm_cache_stats
if self.gather_uvm_cache_stats
else None,
output_dtype=self.output_dtype,
vbe_metadata=vbe_metadata,
is_experimental=self.is_experimental,
Expand Down Expand Up @@ -1206,6 +1211,12 @@ def _prefetch(self, indices: Tensor, offsets: Tensor) -> None:
if not self.lxu_cache_weights.numel():
return

# Clear the local_uvm_cache_stats before the prefetch instead of after
# the prefetch step, since it will be used in the CommonArgs in the
# forward step
if self.gather_uvm_cache_stats:
self.local_uvm_cache_stats.zero_()

linear_cache_indices = torch.ops.fbgemm.linearize_cache_indices(
self.cache_hash_size_cumsum,
indices,
Expand Down Expand Up @@ -1287,7 +1298,6 @@ def _prefetch(self, indices: Tensor, offsets: Tensor) -> None:
self.uvm_cache_stats = torch.add(
self.uvm_cache_stats, self.local_uvm_cache_stats
)
self.local_uvm_cache_stats.zero_()

def _prefetch_tensors_record_stream(
self, forward_stream: torch.cuda.Stream
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,7 @@ def forward(
indice_weights=per_sample_weights,
feature_requires_grad=feature_requires_grad,
lxu_cache_locations=lxu_cache_locations,
uvm_cache_stats=None,
vbe_metadata=invokers.lookup_args.VBEMetadata(
B_offsets=None,
output_offsets_feature_rank=None,
Expand Down

0 comments on commit 7bbb442

Please sign in to comment.