Skip to content

Commit

Permalink
Make the scratch pad tensor UVA (pytorch#2844)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#2844

Before this diff, the scratch pad in SSD TBE (see D55998215 for more
detail) was a CPU tensor which was later transferred to GPU to allow
the TBE kernels to access it.  The scratch pad tranfer was highly
inefficient since TBE over provisioned the scratch pad buffer
allocation (as it did not know the exact number of cache missed rows)
causing extra data transfer.  Such the extra data transfer could be
large since the number of cache missed rows was normally much smaller
than value that TBE over provisioned.

There are two ways to avoid the extra data transfer:

(1) Let TBE have the exact number of cache missed rows on host which
requires device-to-host data transfer which will cause a sync point
between host and device (not desirable in most trainings).
However, this will allow TBE to use `cudaMemcpy` which will utilize
the DMA engine and will allow the memory copy to overlap efficiently
with other compute kernels.

(2) Make the scratch pad accessible by both CPU and GPU.  In other
words, make the scratch pad a UVA tensor.  This does not require
device and host synchornization.  However, the memory copy has to be
done through CUDA load/store which requires a kernel to run on SMs.
Thus, the memory copy and compute kernel overlapping will require a
careful SMs management.

Based on the tradeoffs explained above, we chose to implement (2)
to avoid the host and device sync point.

Differential Revision: D58631974
  • Loading branch information
sryap authored and facebook-github-bot committed Jul 14, 2024
1 parent 27ef127 commit cb0be42
Showing 1 changed file with 87 additions and 51 deletions.
138 changes: 87 additions & 51 deletions fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ def __init__(
self.ssd_event_evict_sp = torch.cuda.Event()

self.timesteps_prefetched: List[int] = []
self.ssd_scratch_pads: List[Tuple[Tensor, Tensor, Tensor]] = []
self.ssd_scratch_pads: List[Tuple[Tensor, Tensor, Tensor, bool]] = []
# TODO: add type annotation
# pyre-fixme[4]: Attribute must be annotated.
self.ssd_prefetch_data = []
Expand Down Expand Up @@ -397,48 +397,71 @@ def to_pinned_cpu(self, t: torch.Tensor) -> torch.Tensor:

def evict(
self,
evicted_rows: Tensor,
evicted_indices: Tensor,
rows: Tensor,
indices: Tensor,
actions_count_cpu: Tensor,
eviction_stream: torch.cuda.Stream,
stream: torch.cuda.Stream,
pre_event: torch.cuda.Event,
post_event: torch.cuda.Event,
is_rows_uvm: bool,
) -> None:
"""
Evict data from the given input tensors to SSD via RocksDB
Args:
rows (Tensor): The 2D tensor that contains rows to evict
indices (Tensor): The 1D tensor that contains the row indices that
the rows will be evicted to
actions_count_cpu (Tensor): A scalar tensor that contains the
number of rows that the evict function
has to process
stream (Stream): The CUDA stream that cudaStreamAddCallback will
synchronize the host function with. Moreover, the
asynchronous D->H memory copies will operate on
this stream
pre_event (Event): The CUDA event that the stream has to wait on
post_event (Event): The CUDA event that the current will record on
when the eviction is done
is_rows_uvm (bool): A flag to indicate whether `rows` is a UVM
tensor (which is accessible on both host and
device)
Returns:
None
"""
with torch.cuda.stream(eviction_stream):
eviction_stream.wait_event(pre_event)
with torch.cuda.stream(stream):
stream.wait_event(pre_event)

evicted_rows_cpu = self.to_pinned_cpu(evicted_rows)
evicted_indices_cpu = self.to_pinned_cpu(evicted_indices)
rows_cpu = rows if is_rows_uvm else self.to_pinned_cpu(rows)
indices_cpu = self.to_pinned_cpu(indices)

evicted_rows.record_stream(eviction_stream)
evicted_indices.record_stream(eviction_stream)
rows.record_stream(stream)
indices.record_stream(stream)

self.ssd_db.set_cuda(
evicted_indices_cpu, evicted_rows_cpu, actions_count_cpu, self.timestep
indices_cpu, rows_cpu, actions_count_cpu, self.timestep
)

# TODO: is this needed?
# Need a way to synchronize
# actions_count_cpu.record_stream(self.ssd_stream)
eviction_stream.record_event(post_event)
stream.record_event(post_event)

def _evict_from_scratch_pad(self, grad: Tensor) -> None:
assert len(self.ssd_scratch_pads) > 0, "There must be at least one scratch pad"
(inserted_rows_gpu, post_bwd_evicted_indices, actions_count_cpu) = (
(inserted_rows, post_bwd_evicted_indices, actions_count_cpu, do_evict) = (
self.ssd_scratch_pads.pop(0)
)
torch.cuda.current_stream().record_event(self.ssd_event_backward)
self.evict(
inserted_rows_gpu,
post_bwd_evicted_indices,
actions_count_cpu,
self.ssd_stream,
self.ssd_event_backward,
self.ssd_event_evict_sp,
)
if do_evict:
torch.cuda.current_stream().record_event(self.ssd_event_backward)
self.evict(
inserted_rows,
post_bwd_evicted_indices,
actions_count_cpu,
self.ssd_stream,
self.ssd_event_backward,
self.ssd_event_evict_sp,
is_rows_uvm=True,
)

def _compute_cache_ptrs(
self,
Expand All @@ -447,7 +470,7 @@ def _compute_cache_ptrs(
linear_index_inverse_indices: torch.Tensor,
unique_indices_count_cumsum: torch.Tensor,
cache_set_inverse_indices: torch.Tensor,
inserted_rows_gpu: torch.Tensor,
inserted_rows: torch.Tensor,
unique_indices_length: torch.Tensor,
inserted_indices: torch.Tensor,
actions_count_cpu: torch.Tensor,
Expand All @@ -468,7 +491,7 @@ def _compute_cache_ptrs(
unique_indices_count_cumsum,
cache_set_inverse_indices,
self.lxu_cache_weights,
inserted_rows_gpu,
inserted_rows,
unique_indices_length,
inserted_indices,
)
Expand All @@ -477,14 +500,19 @@ def _compute_cache_ptrs(
with record_function("## ssd_scratch_pads ##"):
# Store scratch pad info for post backward eviction
self.ssd_scratch_pads.append(
(inserted_rows_gpu, post_bwd_evicted_indices, actions_count_cpu)
(
inserted_rows,
post_bwd_evicted_indices,
actions_count_cpu,
linear_cache_indices.numel() > 0,
)
)

# pyre-fixme[7]: Expected `Tensor` but got `Tuple[typing.Any, Tensor,
# typing.Any, Tensor]`.
return (
lxu_cache_ptrs,
inserted_rows_gpu,
inserted_rows,
post_bwd_evicted_indices,
actions_count_cpu,
)
Expand Down Expand Up @@ -522,42 +550,50 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> Optional[Tensor]:
evicted_rows = self.lxu_cache_weights[
assigned_cache_slots.clamp(min=0).long(), :
]
inserted_rows = torch.empty(
evicted_rows.shape,
dtype=self.lxu_cache_weights.dtype,
pin_memory=True,
)

if linear_cache_indices.numel() > 0:
inserted_rows = torch.ops.fbgemm.new_managed_tensor(
torch.zeros(
1, device=self.current_device, dtype=self.lxu_cache_weights.dtype
),
evicted_rows.shape,
)
else:
inserted_rows = torch.empty(
evicted_rows.shape,
dtype=self.lxu_cache_weights.dtype,
device=self.current_device,
)

current_stream = torch.cuda.current_stream()

inserted_indices_cpu = self.to_pinned_cpu(inserted_indices)

# Ensure the previous iterations l3_db.set(..) has completed.
current_stream.wait_event(self.ssd_event_evict)
current_stream.wait_event(self.ssd_event_evict_sp)

self.ssd_db.get_cuda(
self.to_pinned_cpu(inserted_indices), inserted_rows, actions_count_cpu
)
if linear_cache_indices.numel() > 0:
self.ssd_db.get_cuda(inserted_indices_cpu, inserted_rows, actions_count_cpu)
current_stream.record_event(self.ssd_event_get)
# TODO: T123943415 T123943414 this is a big copy that is (mostly) unnecessary with a decent cache hit rate.
# Should we allocate on HBM?
inserted_rows_gpu = inserted_rows.cuda(non_blocking=True)

torch.ops.fbgemm.masked_index_put(
self.lxu_cache_weights,
assigned_cache_slots,
inserted_rows_gpu,
inserted_rows,
actions_count_gpu,
)

# Evict rows from cache to SSD
self.evict(
evicted_rows,
evicted_indices,
actions_count_cpu,
self.ssd_stream,
self.ssd_event_get,
self.ssd_event_evict,
)
if linear_cache_indices.numel() > 0:
# Evict rows from cache to SSD
self.evict(
evicted_rows,
evicted_indices,
actions_count_cpu,
self.ssd_stream,
self.ssd_event_get,
self.ssd_event_evict,
is_rows_uvm=False,
)

# TODO: keep only necessary tensors
self.ssd_prefetch_data.append(
Expand All @@ -567,7 +603,7 @@ def prefetch(self, indices: Tensor, offsets: Tensor) -> Optional[Tensor]:
linear_index_inverse_indices,
unique_indices_count_cumsum,
cache_set_inverse_indices,
inserted_rows_gpu,
inserted_rows,
unique_indices_length,
inserted_indices,
actions_count_cpu,
Expand All @@ -593,7 +629,7 @@ def forward(
prefetch_data = self.ssd_prefetch_data.pop(0)
(
lxu_cache_ptrs,
inserted_rows_gpu,
inserted_rows,
post_bwd_evicted_indices,
actions_count_cpu,
) = self._compute_cache_ptrs(*prefetch_data)
Expand Down Expand Up @@ -635,7 +671,7 @@ def forward(
# codegen/genscript/optimizer_args.py
ssd_tensors={
"row_addrs": lxu_cache_ptrs,
"inserted_rows": inserted_rows_gpu,
"inserted_rows": inserted_rows,
"post_bwd_evicted_indices": post_bwd_evicted_indices,
"actions_count": actions_count_cpu,
},
Expand Down

0 comments on commit cb0be42

Please sign in to comment.