| description |
| --- |
| #DLRM_inference #GPU_embedding_cache |
Presented at SOSP 2023.
Authors: Xiaoniu Song, Yiwen Zhang, Rong Chen, Haibo Chen (SJTU)
- UGache — a unified embedding cache across multiple GPUs
    - Exploits cross-GPU interconnects (e.g., NVLink, NVSwitch)
    - Provides optimal data placement through a unified abstraction over different GPU interconnects and bandwidths
- Performance bottleneck
    - Fetching embedding entries
        - Example: GNN training (31.3 ms)
            - Embedding extraction: 20.7 ms
            - Fetching missing data from host memory: 17.9 ms (of the 20.7 ms)
    - Memory limitation of a single GPU
- Single-GPU embedding cache
    - Each GPU independently caches only the hottest embeddings
    - The hit rate is limited by a single GPU's memory capacity, so embedding extraction still dominates the end-to-end time
- Multi-GPU embedding cache
    - Existing designs fail to address the challenges of cache policy and extraction mechanism (a cost sketch comparing the two policies follows this list)
    - Replication cache
        - Directly ports the single-GPU cache to a multi-GPU platform → each GPU caches the hottest entries independently
        - Drawbacks
            - Every GPU covers a similar set of requests
            - The fast cross-GPU interconnect bandwidth goes unused
    - Partition cache
        - Caches as many distinct entries as possible and serves the majority of accesses through fast GPU interconnects
        - Drawbacks
            - Increasing cache capacity under a partition policy does not raise hit rates proportionally, since under skewed access the added capacity mostly holds cold entries
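To make the policy tradeoff concrete, here is a back-of-the-envelope cost model (a sketch only; the Zipf exponent, capacities, and per-access costs are invented, not measured numbers from the paper): replication maximizes local hits but caps coverage at one GPU's capacity, while partitioning covers more entries at the price of remote accesses.

```python
# Illustrative only: all constants are assumptions, not paper results.
import numpy as np

N, G, CAP = 10_000, 2, 2_000                 # table entries, GPUs, slots/GPU
T_LOCAL, T_REMOTE, T_HOST = 1.0, 3.0, 15.0   # per-access costs (made up)

p = 1.0 / np.arange(1, N + 1) ** 1.05        # power-law access probabilities
p /= p.sum()

# Replication: every GPU caches the same top-CAP entries.
hit = p[:CAP].sum()
t_rep = hit * T_LOCAL + (1 - hit) * T_HOST

# Partition: the top G*CAP entries are split evenly; each GPU hits 1/G of
# them locally and the rest over the interconnect.
cov = p[:G * CAP].sum()
t_part = cov / G * T_LOCAL + cov * (G - 1) / G * T_REMOTE + (1 - cov) * T_HOST

print(f"replication: {t_rep:.2f}, partition: {t_part:.2f} (cost per access)")
```

Depending on the skew and link speeds, either policy can win, which is exactly the gap UGache's solver navigates.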
- Read-only
    - Inference
    - Pre-training → the embedding table is trained beforehand and distributed to different downstream workloads
- Batched access
    - Each step accesses only a subset of embedding entries, using batched sparse inputs as keys (gather sketch below)
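A minimal PyTorch sketch of what batched access looks like (table size, dimension, and batch size are arbitrary): one step touches only the rows named by the batch's sparse keys.

```python
import torch

table = torch.randn(1_000_000, 128)          # embedding table (N x dim)
keys = torch.randint(0, 1_000_000, (4096,))  # batched sparse input keys
batch = table[keys]                          # extraction: a (4096, 128) gather
print(batch.shape)
```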
- Skewed hotness
    - Access frequencies follow a power-law distribution
- Predictable access pattern
    - The skew of embedding accesses is predictable and stable over time (checked in the sketch below)
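A quick self-contained check of the predictability claim (a synthetic Zipf trace; the exponent and sizes are our choices, not data from the paper): a hot set profiled on the first half of a trace still covers the second half.

```python
import numpy as np

rng = np.random.default_rng(0)
N = 100_000
trace = rng.zipf(1.2, size=2_000_000) % N           # synthetic skewed trace
c1 = np.bincount(trace[:1_000_000], minlength=N)    # profile on first half
c2 = np.bincount(trace[1_000_000:], minlength=N)    # evaluate on second half
hot = np.argsort(-c1)[:N // 100]                    # top 1% by profiled hotness
print(f"coverage on the unseen half: {c2[hot].sum() / c2.sum():.0%}")
```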
- Extractor
    - A factored extraction mechanism that extracts embedding entries from multiple sources (sketched below)
    - Statically dedicates GPU cores to access different sources
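A CPU-side sketch of the factored-extraction idea (function and variable names are ours, and indexing every store by the global key is a simplification; the real extractor uses key-to-slot mappings and runs as GPU kernels with thread blocks statically dedicated per source):

```python
import torch

LOCAL, REMOTE, HOST = 0, 1, 2

def extract(keys, placement, sources):
    """Gather embedding rows for `keys`, grouping keys by their source."""
    out = torch.empty(keys.numel(), sources[HOST].shape[1])
    for src in (LOCAL, REMOTE, HOST):
        mask = placement[keys] == src         # keys resident in this source
        out[mask] = sources[src][keys[mask]]  # one gather per source
    return out

# Toy usage: 1,000-entry table, dim 8, random placement decisions.
host = torch.randn(1000, 8)
sources = {LOCAL: host.clone(), REMOTE: host.clone(), HOST: host}
placement = torch.randint(0, 3, (1000,))
print(extract(torch.randint(0, 1000, (64,)), placement, sources).shape)
```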
- Solver
    - Balances caching more distinct entries (to improve the global hit rate) against caching more replicas (to improve the local hit rate)
    - Defines a hotness metric that measures the access frequency of each entry
    - Profiles the hardware platform to estimate embedding extraction time
    - Uses mixed-integer linear programming (MILP) to solve for a cache policy that minimizes extraction time (toy formulation below)
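A toy rendering of the MILP idea in PuLP (a deliberate simplification: the paper's formulation also models interconnect topology and bandwidth, and every constant and hotness value below is invented). Binary `store` variables choose the placement; `loc`/`rem` account for whether each GPU's accesses are served locally, from a peer GPU, or from host memory.

```python
import pulp

E, G, CAP = 8, 2, 3                           # entries, GPUs, slots per GPU
T_LOCAL, T_REMOTE, T_HOST = 1.0, 3.0, 15.0    # assumed per-access costs
hot = [100, 60, 30, 20, 10, 5, 3, 1]          # assumed access frequencies

prob = pulp.LpProblem("placement", pulp.LpMinimize)
x = pulp.LpVariable.dicts("store", (range(E), range(G)), cat="Binary")
loc = pulp.LpVariable.dicts("loc", (range(E), range(G)), 0, 1)
rem = pulp.LpVariable.dicts("rem", (range(E), range(G)), 0, 1)

# Objective: total extraction time; cache misses pay the host-memory cost.
prob += pulp.lpSum(
    hot[e] * (T_LOCAL * loc[e][g] + T_REMOTE * rem[e][g]
              + T_HOST * (1 - loc[e][g] - rem[e][g]))
    for e in range(E) for g in range(G))

for g in range(G):                            # per-GPU cache capacity
    prob += pulp.lpSum(x[e][g] for e in range(E)) <= CAP
for e in range(E):
    for g in range(G):
        prob += loc[e][g] <= x[e][g]          # local hit needs a local copy
        prob += rem[e][g] <= pulp.lpSum(      # remote hit needs a peer copy
            x[e][gp] for gp in range(G) if gp != g)
        prob += loc[e][g] + rem[e][g] <= 1    # each access has one source

prob.solve(pulp.PULP_CBC_CMD(msg=False))
for e in range(E):
    gpus = [g for g in range(G) if pulp.value(x[e][g]) > 0.5]
    print(f"entry {e} cached on GPUs {gpus}")
```

On such inputs the solver tends to replicate the hottest entries (local hits everywhere) and keep a single copy of the tail reachable over the interconnect, which is the replication/partition balance described above.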
- Integrated into TensorFlow and PyTorch
- Two representative applications: DLRM inference and GNN training