
[Core] offload model weights to CPU conditionally #6317

Closed

Conversation

@chenqianfzh
Contributor

chenqianfzh commented Jul 10, 2024

We are developing a "conditional CPU weight offload" feature for vLLM, comparable to Hugging Face Accelerate's device_map='auto'. This democratizes access to vLLM, empowering a broader community of learners and researchers to engage with cutting-edge AI models.

To achieve conditional CPU offload, a new CLI parameter, cpu_offload_trigger_percent, whose default value is 0, will be added.

  • When cpu_offload_trigger_percent == 1, the conditional CPU offload feature is turned off and vLLM behaves exactly as it does today.
  • When 0 <= cpu_offload_trigger_percent < 1, vLLM loads weights onto the GPU until GPU memory usage reaches cpu_offload_trigger_percent; any remaining weights are offloaded to the CPU (see the sketch below).
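
A minimal sketch of that placement rule (hypothetical helper, not the exact code in this PR):

    import torch

    def choose_device(param_nbytes: int, trigger_percent: float) -> torch.device:
        # Hypothetical illustration of the cpu_offload_trigger_percent rule:
        # keep placing weights on the GPU until its memory usage would cross
        # the threshold, then fall back to CPU for the remaining weights.
        free, total = torch.cuda.mem_get_info()   # bytes free / total on the current GPU
        used_fraction = (total - free + param_nbytes) / total
        return torch.device("cuda") if used_fraction <= trigger_percent else torch.device("cpu")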

Test results show:

  1. When the percentage specified by the user is enough to hold the weights, the performance will be the same as of now.
  2. When the percentage specified by the user is insufficient to hold the weights, vLLM will continue to work, with higher latency. In cases of large models with very limited GPU resources, the latency could be very high. However, vLLM will still work and generate outputs.

A more detailed doc is given at https://docs.google.com/document/d/1qY9SJIRCLn6_p_2jHInW0E8qE2oukgOrH26v62jZy70/edit#heading=h.vrhc1mfm2tpm

@youkaichao
Member

When the percentage specified by the user is insufficient to hold the weights, the vLLM will continue to work with some latency.

how is it possible?

how does it work with CUDA graphs?

@comaniac
Collaborator

I have the same question about how the latency could stay the same with offloading. Based on the code change, the offloaded weights are transferred to the GPU synchronously when needed, without prefetching. This should introduce significant latency overhead.
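
(As a minimal, hypothetical illustration of that synchronous pattern, not the code in this PR: copying an offloaded layer's weights in right before it runs means nothing overlaps the transfer, so its full cost lands on forward latency.)

    import torch
    import torch.nn as nn

    def bring_weights_to_gpu(module: nn.Module, _inputs):
        # Synchronously move any CPU-resident parameters to the GPU just before
        # this module's forward pass. A real offload scheme would also copy or
        # evict them back afterwards; this only shows the blocking transfer.
        for _, param in module.named_parameters(recurse=False):
            if not param.is_cuda:
                param.data = param.data.to("cuda", non_blocking=False)

    # Hypothetical wiring: layer.register_forward_pre_hook(bring_weights_to_gpu)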

@chenqianfzh
Contributor Author

chenqianfzh commented Jul 11, 2024

@youkaichao @comaniac Thanks for looking into my PR.

As mentioned in the PR description, "When the percentage specified by the user is enough to hold the weights, the performance will be the same as of now." That is, CPU offload is not triggered when GPU memory is sufficient, and CUDA graphs are applied as they are today, so performance is the same.

CPU offload kicks in only when the percentage specified by the user is insufficient to hold the weights, and CUDA graphs are turned off in that case. That is why I call this feature "conditional". The latency will certainly be much higher in such cases, but vLLM will continue to work, which is valuable for users with limited GPU resources.

The following code from worker.py in this PR decides whether to turn CUDA graphs on or off:

        # Check whether every model parameter resides on the GPU.
        all_weight_in_gpu = True
        for name, param in self.model_runner.model.named_parameters():
            if not param.is_cuda:
                all_weight_in_gpu = False
                break

        # Capture CUDA graphs (warm-up) only when no weights are offloaded.
        if all_weight_in_gpu:
            self._warm_up_model()

Similar logic is given in vllm/worker/model_runner.py in this PR.

Please let me know if you have any other concerns. Thanks.

@comaniac
Collaborator

Oh I guess the confusion comes from this statement: When the percentage specified by the user is insufficient to hold the weights, the vLLM will continue to work with some latency. Maybe you could change "some latency" to "higher latency"...

Meanwhile, CPU offloading can be optimized in multiple ways. For example, we could prefetch weights with CUDA stream to hide data transfer latency. It might be a good idea to have an RFC to document the scope, roadmap and milestones.

In addition to that, we should think more about the API design. The name cpu_offload_trigger_percent could be confusing, because potentially we could offload not just weights but also the KV cache.

@chenqianfzh
Contributor Author

Oh I guess the confusion comes from this statement: When the percentage specified by the user is insufficient to hold the weights, the vLLM will continue to work with some latency. Maybe you could change "some latency" to "higher latency"...

I revised the description to: "2. When the percentage specified by the user is insufficient to hold the weights, vLLM will continue to work, with higher latency. In cases of large models with very limited GPU resources, the latency could be very high. However, vLLM still works and generates outputs." I hope that explains it better.

Meanwhile, CPU offloading can be optimized in multiple ways. For example, we could prefetch weights with CUDA stream to hide data transfer latency. It might be a good idea to have an RFC to document the scope, roadmap and milestones.

As to "prefetching", in the doc https://docs.google.com/document/d/1qY9SJIRCLn6_p_2jHInW0E8qE2oukgOrH26v62jZy70/edit#heading=h.vrhc1mfm2tpm(it is pretty long, so I just gave the link in the description"), I discussed about it in the section "What is Next". As of now, based on my test results, I don't see the point of "prefetching" in vLLM. Based on the tests on AWS a g6.12xlarge machine(that is the most powerful GPU machine I can get), the latency to transfer weight from CPU to GPU is tens and sometimes eve hundreds comparing the tensor multiplication time (the test results are also given in the good doc above).

I guess "prefetching" will be practical when the machines with much better transfer speed between CPU and GPU, such as GH200/BG200, have greater availability.

As to the roadmap, could you point me to the doc? I searched the repo and only found some threads about cpu_offload.

In addition to that, we should think more about the API design. The name cpu_offload_trigger_percent could be confusing, because potentially we could offload not just weights but also the KV cache.

Agree. How about I change the variable name to "weight_cpu_offload_trigger_percent"?

Thanks.

@comaniac
Collaborator

The point is that whatever the data transfer time is, without prefetching it adds directly to the forward latency, and that may become critical when inter-token latency during decoding is just a few milliseconds. Your results also show that the e2e latency can be several times longer when offloading happens.
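
(As a rough, purely illustrative calculation with assumed numbers rather than figures from this thread: if decoding normally produces a token every ~10 ms, but each forward pass has to pull 8 GB of offloaded weights over a ~25 GB/s PCIe link, the transfers alone add about 8 / 25 ≈ 0.32 s per token, i.e. tens of times the compute time.)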

For the roadmap, it's just your "What's Next" section, but we could make it clearer by defining follow-up features along with their scopes and dependencies.

btw could you turn on commenting in the doc so that everyone could leave comments? Thanks.

@chenqianfzh
Contributor Author

The point is that whatever the data transfer time is, without prefetching it adds directly to the forward latency, and that may become critical when inter-token latency during decoding is just a few milliseconds. Your results also show that the e2e latency can be several times longer when offloading happens.

I get your point. I listed prefetching as one of my follow-up work items in the shared Google doc, where I said I need to investigate it further.

Here is one of the test results given in the doc:

Transformer layer weight load time | Matrix computation time for offloaded weights
176.3s                             | 0.9s

The total latency can only be cut by 0.9s if prefetching is in place, a gain of less than 1% (0.9 / (176.3 + 0.9) ≈ 0.5%). Besides, the vLLM architecture is designed so that each sub-module is encapsulated within its own boundary. Prefetching would break that boundary, since it requires cooperation between the current layer and the next layer. I need to spend some effort to find an approach that is clean from an engineering perspective.

I am more than happy to hear any suggestions and to discuss this. I am a newcomer to vLLM as well as to CPU offloading. :-)

For the roadmap, it's just your "What's Next" section, but we could make it clearer by defining follow-up features along with their scopes and dependencies.

btw could you turn on commenting in the doc so that everyone could leave comments? Thanks.

Got it. The doc is open to comments now.

@drikster80
Contributor

drikster80 commented Jul 15, 2024

This is exactly what I've been looking/waiting for. I'm currently running a small GH200 cluster (4 x 96G) and NCCL is getting ~42GB/s across the cluster. As @chenqianfzh mentioned above, I think the GH200 will see some major benefits from the CPU<>GPU NVLink.

Is there any reason this wouldn't work across a Ray NCCL cluster?

UPDATE: I was able to confirm this works with pipeline-parallelism on multi-node:

python -m vllm.entrypoints.openai.api_server --model /models/Meta-Llama-3-70B-Instruct --pipeline-parallel-size 2 --cpu_offload_trigger_percent 0.50 --distributed-executor-backend ray

Currently getting 5.6 t/s over multi-node (2 nodes), vs 4.6 t/s with a single node, which seems strange.

@drikster80
Contributor

drikster80 commented Jul 16, 2024

Did some initial testing of this on the GH200 with Meta-Llama-3-70B-Instruct. Inference looks like 4.6 t/s, which seems surprisingly slow, given the increased speed between CPU<>GPU on GH200. As a reference, I'm seeing ~20 t/s when distributed between 2 nodes (no cpu-offload) with v0.5.1 (400Gbps network).

Is there something I'm not seeing here? Details below...

Exec flags:

python -m vllm.entrypoints.openai.api_server --model /models/Meta-Llama-3-70B-Instruct --cpu_offload_trigger_percent 0.80

Observations:

GPU memory used: 86224MiB/97871MiB
Avg generation throughput: 4.6 tokens/s

Details

INFO 07-15 23:59:07 api_server.py:206] vLLM API server version 0.5.1
INFO 07-15 23:59:07 api_server.py:207] args: Namespace(host=None, port=8000, uvicorn_log_level='info', allow_credentials=False, allowed_origins=['*'], allowed_methods=['*'], allowed_headers=['*'], api_key=None, lora_modules=None, prompt_adapters=None, chat_template=None, response_role='assistant', ssl_keyfile=None, ssl_certfile=None, ssl_ca_certs=None, ssl_cert_reqs=0, root_path=None, middleware=[], model='/models/Meta-Llama-3-70B-Instruct', tokenizer=None, skip_tokenizer_init=False, revision=None, code_revision=None, tokenizer_revision=None, tokenizer_mode='auto', trust_remote_code=False, download_dir=None, load_format='auto', dtype='auto', kv_cache_dtype='auto', quantization_param_path=None, max_model_len=None, guided_decoding_backend='outlines', distributed_executor_backend=None, worker_use_ray=False, pipeline_parallel_size=1, tensor_parallel_size=1, max_parallel_loading_workers=None, ray_workers_use_nsight=False, block_size=16, enable_prefix_caching=False, disable_sliding_window=False, use_v2_block_manager=False, num_lookahead_slots=0, seed=0, swap_space=4, gpu_memory_utilization=0.9, num_gpu_blocks_override=None, max_num_batched_tokens=None, max_num_seqs=256, max_logprobs=20, disable_log_stats=False, quantization=None, rope_scaling=None, rope_theta=None, enforce_eager=False, max_context_len_to_capture=None, max_seq_len_to_capture=8192, disable_custom_all_reduce=False, tokenizer_pool_size=0, tokenizer_pool_type='ray', tokenizer_pool_extra_config=None, enable_lora=False, max_loras=1, max_lora_rank=16, lora_extra_vocab_size=256, lora_dtype='auto', long_lora_scaling_factors=None, max_cpu_loras=None, fully_sharded_loras=False, enable_prompt_adapter=False, max_prompt_adapters=1, max_prompt_adapter_token=0, device='auto', scheduler_delay_factor=0.0, enable_chunked_prefill=False, speculative_model=None, num_speculative_tokens=None, speculative_draft_tensor_parallel_size=None, speculative_max_model_len=None, speculative_disable_by_batch_size=None, ngram_prompt_lookup_max=None, ngram_prompt_lookup_min=None, spec_decoding_acceptance_method='rejection_sampler', typical_acceptance_sampler_posterior_threshold=None, typical_acceptance_sampler_posterior_alpha=None, model_loader_extra_config=None, preemption_mode=None, served_model_name=None, qlora_adapter_name_or_path=None, otlp_traces_endpoint=None, cpu_offload_trigger_percent=0.8, engine_use_ray=False, disable_log_requests=False, max_log_len=None)
INFO 07-15 23:59:07 llm_engine.py:174] Initializing an LLM engine (v0.5.1) with config: model='/models/Meta-Llama-3-70B-Instruct', speculative_config=None, tokenizer='/models/Meta-Llama-3-70B-Instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=8192, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None), seed=0, served_model_name=/models/Meta-Llama-3-70B-Instruct, use_v2_block_manager=False, enable_prefix_caching=False)
INFO 07-16 00:03:30 model_runner.py:268] Loading model weights took 77.9497 GB
INFO 07-16 00:03:32 gpu_executor.py:86] # GPU blocks: 921, # CPU blocks: 819
<...SNIP...>
INFO 07-16 00:06:11 metrics.py:295] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 4.6 tokens/s, Running: 1 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 1.8%, CPU KV cache usage: 0.0%.

@chenqianfzh
Contributor Author


Thanks for trying it on a GH200 machine.

Using CPU offload is expected to incur higher latency due to the following two factors:

  1. The extra latency of loading the weights from CPU to GPU.
  2. The CUDA graph optimization has to be turned off because some weights are not on the GPU.

The latency caused by the second factor cannot be alleviated by a fast link between CPU and GPU.

I wonder whether you have profiling tools to separate the two latencies: tensor movement and the weight computation (without the CUDA graph optimization). With that data, we will have a better idea of how to optimize further.

Thanks

@drikster80
Contributor

I wonder whether you have profiling tools to separate the two latencies: tensor movement and the weight computation (without the CUDA graph optimization)

Haven't done something like that before, but will do some research. If anyone has any advice or tools for profiling vLLM and PyTorch memory usage, please reply and I'll get into it.
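
(One generic option, offered here as a plain PyTorch suggestion rather than anything vLLM-specific: wrap the call being measured in torch.profiler to split time between CPU ops, CUDA kernels, and Host-to-Device copies.)

    from torch.profiler import profile, ProfilerActivity

    def run_inference():
        # Placeholder for the call being measured, e.g. a generate() call
        # or a single forward pass of the model.
        pass

    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        run_inference()

    # Sorting by CUDA time shows whether Host-to-Device memcpy dominates.
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))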

Have been thinking about this statement:

The total latency can only be cut by 0.9s if prefetching is in place

If the prefetching is done as a sliding window over the sequential layers, with the window size set by cpu_offload_trigger_percent, then you'd likely see much better gains than 0.9s. For a simple example, on a 100-layer model with 50% offloaded to system memory, when layer 1 completes you'd offload layer 1 and import layer 51 (see the sketch below). The matmul would still catch up to the transfers fairly quickly, but on systems with fast CPU<>GPU links it would be increasingly efficient.
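
(A rough sketch of that sliding-window idea, with hypothetical helper names and a per-layer granularity that this PR does not implement: a side CUDA stream copies a future layer's weights while the current layer computes, and an event gates each layer on its own copy.)

    import torch

    copy_stream = torch.cuda.Stream()   # side stream dedicated to weight copies

    def prefetch(layer):
        # Enqueue async Host-to-Device copies for one layer on the side stream
        # and return an event the compute stream can wait on right before the
        # layer runs. The copies only overlap compute if the CPU tensors are in
        # pinned memory; a real implementation would also evict layers again to
        # stay within the GPU memory budget.
        event = torch.cuda.Event()
        with torch.cuda.stream(copy_stream):
            for p in layer.parameters():
                if not p.is_cuda:
                    p.data = p.data.to("cuda", non_blocking=True)
            event.record(copy_stream)
        return event

    def run_layers(layers, hidden_states, window):
        # window >= 1: how many layers ahead to keep "in flight".
        events = [prefetch(l) for l in layers[:window]]
        for i, layer in enumerate(layers):
            torch.cuda.current_stream().wait_event(events[i])  # weights for layer i have arrived
            hidden_states = layer(hidden_states)
            if i + window < len(layers):
                events.append(prefetch(layers[i + window]))    # keep the window full
        return hidden_states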

I'm brand new to the vLLM architecture, so I'm not the person to implement this, but I'm happy to perform testing or give access to hardware (I've got A100s, both with and without NVLink, as well as a few GH200s lying around to play with).

@drikster80
Contributor

I did some quick profiling with nsys and a modified benchmark_latency.py that adds the cpu_offload_trigger_percent flag. This is my first time using Nsight profiling, so I'm still learning, but the output looks like what is expected: the vast majority of the time (99.9%) is spent on CUDA memcpy Host-to-Device.

Profiling iterations: 100%|██████████| 30/30 [13:54<00:00, 27.83s/it]
Avg latency: 27.825364518566765 seconds
10% percentile latency: 27.748035688600247 seconds
25% percentile latency: 27.78808914725016 seconds
50% percentile latency: 27.822442912999804 seconds
75% percentile latency: 27.870453400750193 seconds
90% percentile latency: 27.89430141380053 seconds
99% percentile latency: 27.96023444084967 seconds
Generating '/tmp/nsys-report-30ec.qdstrm'
[1/7] [========================100%] report1.nsys-rep
[2/7] [========================100%] report1.sqlite
[3/7] Executing 'nvtx_sum' stats report
SKIPPED: /vllm-workspace/vllm/benchmarks/report1.sqlite does not contain NV Tools Extension (NVTX) data.
[4/7] Executing 'cuda_api_sum' stats report

Time (%)  Total Time (ns)  Num Calls   Avg (ns)    Med (ns)   Min (ns)   Max (ns)   StdDev (ns)                        Name
 --------  ---------------  ---------  ----------  ----------  --------  ----------  -----------  ------------------------------------------------
     95.7    1126506416224    1219248    923935.4    370272.0      2048  7223082048   19872307.5  cudaMemcpyAsync
      1.9      22880455968    3084874      7417.0      7552.0      1760    40032992      33868.7  cudaLaunchKernel
      1.2      13746969600    1643841      8362.7      8256.0      4896      677760       2738.2  cudaLaunchKernelExC_v11060
      0.6       6888177792        288  23917284.0  11644192.0      3616   186973440   27176905.9  cudaMalloc
      0.4       4238695072     836803      5065.3      4992.0      1248    46411520      50774.8  cudaMemsetAsync
      0.2       1852857600     753236      2459.9      2240.0      1376      648384       1822.3  cudaStreamSynchronize
      0.0        335088096         89   3765034.8   3019232.0      3264    70269312    7294763.5  cudaHostAlloc
      0.0        162171552          2  81085776.0  81085776.0     45440   162126112  114608342.3  cudaMemGetInfo
      0.0         71116352      56322      1262.7       480.0       288      742720       3414.3  cudaEventQuery
      0.0         56782912      56331      1008.0       704.0       352       62720        784.4  cudaEventRecord
      0.0          6518432      10529       619.1       704.0       160        7904        382.6  cudaStreamIsCapturing_v10000
      0.0          5271872         10    527187.2    584672.0     21152      857728     269456.5  cuLibraryLoadData
      0.0          4260256          9    473361.8    327392.0    184960     1120832     340253.7  cudaFree
      0.0           292000          1    292000.0    292000.0    292000      292000          0.0  cudaStreamCreate
      0.0           166496        812       205.0       224.0        32        1024        148.0  cuGetProcAddress_v2
      0.0            15680         27       580.7       256.0       160        2912        812.2  cudaEventCreateWithFlags
      0.0            13344          2      6672.0      6672.0      3936        9408       3869.3  cudaDeviceSynchronize
      0.0            13088         31       422.2       352.0       160        2528        495.9  cudaOccupancyMaxActiveClusters_v11070
      0.0             7488          3      2496.0      1952.0      1568        3968       1289.2  cuInit
      0.0             7040          1      7040.0      7040.0      7040        7040          0.0  cudaStreamDestroy
      0.0             6560         10       656.0       640.0       384         992        176.1  cuLibraryGetKernel
      0.0             3872          3      1290.7       128.0       128        3616       2013.8  cuModuleGetLoadingMode
      0.0             3808          2      1904.0      1904.0      1728        2080        248.9  cudaOccupancyAvailableDynamicSMemPerBlock_v10200
      0.0             1088          2       544.0       544.0       224         864        452.5  cudaGetDriverEntryPoint_v11030

[5/7] Executing 'cuda_gpu_kern_sum' stats report

 Time (%)  Total Time (ns)  Instances  Avg (ns)   Med (ns)   Min (ns)  Max (ns)  StdDev (ns)                                                  Name
 --------  ---------------  ---------  ---------  ---------  --------  --------  -----------  ----------------------------------------------------------------------------------------------------
     78.8     177724300867    1219200   145771.2   137728.0     40256    262560      87603.7  sm90_xmma_gemm_bf16bf16_bf16f32_f32_tn_n_tilesize64x128x64_warpgroupsize1x1x1_execute_segment_k_off…
      9.6      21552016106     406400    53031.5    52928.0     50624     62496       1086.0  sm90_xmma_gemm_bf16bf16_bf16f32_f32_tn_n_tilesize64x64x64_warpgroupsize1x1x1_execute_segment_k_on_k…
      3.0       6695658789     409680    16343.6    16160.0     15360    654368       8793.6  void vllm::act_and_mul_kernel<c10::BFloat16, &vllm::silu_kernel<c10::BFloat16>>(T1 *, const T1 *, i…
      1.7       3746577919     406400     9218.9     8352.0      7936     13504       1567.6  void flash_fwd_splitkv_kernel<Flash_fwd_kernel_traits<(int)128, (int)64, (int)128, (int)4, (bool)0,…
      1.4       3139384076     819360     3831.5     3776.0      3520    153728       2071.0  std::enable_if<T2>(int)0&&vllm::_typeConvert<T1>::exists, void>::type vllm::fused_add_rms_norm_kern…
      1.3       2901027877       5120   566607.0   566496.0    562464    581568       1689.6  sm90_xmma_gemm_bf16bf16_bf16f32_f32_tn_n_tilesize64x128x64_warpgroupsize1x1x1_execute_segment_k_on_…
      1.2       2816592811       9761   288555.8   171200.0     54400  10552160     917616.3  sm90_xmma_gemm_bf16bf16_bf16f32_f32_tn_n_tilesize128x128x64_warpgroupsize1x1x1_execute_segment_k_of…
      1.0       2185447331     409680     5334.5     5248.0      4768    181824       2433.4  void vllm::rotary_embedding_kernel<c10::BFloat16, (bool)1>(const long *, T1 *, T1 *, const T1 *, in…
      0.5       1238684447     406400     3047.9     3040.0      2912      3840         41.1  void flash_fwd_splitkv_combine_kernel<Flash_fwd_kernel_traits<(int)128, (int)64, (int)128, (int)4, …
      0.5       1175692800     409600     2870.3     2848.0      2560      4384        136.3  void vllm::reshape_and_cache_flash_kernel<c10::BFloat16>(const T1 *, const T1 *, T1 *, T1 *, const …
      0.2        466115078        160  2913219.2  2798224.0   1250432   4655040    1594857.7  sm90_xmma_gemm_bf16bf16_bf16f32_f32_tn_n_tilesize256x128x64_warpgroupsize2x1x1_execute_segment_k_of…
      0.1        279222913       5121    54525.1    54464.0     53728    191328       1932.0  void at::native::<unnamed>::cunn_SoftMaxForward<(int)4, float, float, float, at::native::<unnamed>:…
      0.1        272676446       3200    85211.4    84864.0     83104     93600       1469.7  sm90_xmma_gemm_bf16bf16_bf16f32_f32_tn_n_tilesize128x256x64_warpgroupsize2x1x1_execute_segment_k_of…
      0.1        227309855       5121    44387.8    44352.0     43840    164064       1685.0  void at::native::<unnamed>::cunn_SoftMaxForward<(int)4, float, float, float, at::native::<unnamed>:…
      0.1        186346016      20484     9097.1     9024.0      8512     70848        883.7  void at::native::mbtopk::radixFindKthValues<float, unsigned int, unsigned int, (int)2>(at::cuda::de…
      0.1        144587840       5121    28234.3    28224.0     27360     55904        457.3  void at::native::reduce_kernel<(int)512, (int)1, at::native::ReduceOp<float, at::native::ArgMaxOps<…
      0.0        101521471      20484     4956.1     5888.0      2400    160448       2745.8  void at::native::index_elementwise_kernel<(int)128, (int)4, void at::native::gpu_index_kernel<void …
      0.0         84796384       5121    16558.6    16544.0     16224     80608        902.3  void at::native::reduce_kernel<(int)512, (int)1, at::native::ReduceOp<long, at::native::func_wrappe…
      0.0         75569664       3280    23039.5    15392.0     14592    344256      48637.3  void flash_fwd_kernel<Flash_fwd_kernel_traits<(int)128, (int)128, (int)64, (int)4, (bool)0, (bool)0…
      0.0         66707007      10200     6539.9     6016.0      5376      9664       1003.2  void at::native::<unnamed>::indexSelectSmallIndex<c10::BFloat16, long, unsigned int, (int)2, (int)2…
      0.0         54025632      20484     2637.5     2656.0      2496      5376         72.6  void at::native::mbtopk::computeBlockwiseWithinKCounts<unsigned int>(T1 *, short *, unsigned int, i…
      0.0         41242368      10242     4026.8     4064.0      3744    100224       1341.0  void at::native::unrolled_elementwise_kernel<at::native::direct_copy_kernel_cuda(at::TensorIterator…
      0.0         40584864       5121     7925.2     7872.0      6976    130880       1736.4  void at::native::mbtopk::gatherTopK<float, unsigned int, (int)2>(at::cuda::detail::TensorInfo<const…
      0.0         33501152       5121     6541.9     6464.0      6176    211392       2869.7  void vllm::rms_norm_kernel<c10::BFloat16>(T1 *, const T1 *, const T1 *, float, int, int)
      0.0         31939776      10242     3118.5     3136.0      2880      3808        173.0  void at_cuda_detail::cub::DeviceScanByKeyKernel<at_cuda_detail::cub::DeviceScanByKeyPolicy<at_cuda_…
      0.0         31451712       5121     6141.7     6112.0      5568     99040       1309.3  void at::native::elementwise_kernel<(int)128, (int)4, void at::native::gpu_kernel_impl_nocast<at::n…
      0.0         24348960       5122     4753.8     4736.0      2272    109344       1465.2  void at::native::unrolled_elementwise_kernel<at::native::direct_copy_kernel_cuda(at::TensorIterator…
      0.0         24226112       5121     4730.7     4704.0      4640     57216        735.7  void at::native::<unnamed>::distribution_elementwise_grid_stride_kernel<float, (int)4, void at::nat…
      0.0         23782880       5121     4644.2     4640.0      4448     90208       1197.0  void at::native::elementwise_kernel<(int)128, (int)4, void at::native::gpu_kernel_impl_nocast<at::n…
      0.0         17712032       5121     3458.7     3520.0      2944    109504       1493.7  void at::native::vectorized_elementwise_kernel<(int)4, at::native::BinaryFunctor<float, float, floa…
      0.0         15707968      10242     1533.7     1568.0      1344      2336        132.5  void at_cuda_detail::cub::DeviceScanKernel<at_cuda_detail::cub::DeviceScanPolicy<int, std::plus<int…
      0.0         12496000       5121     2440.1     2432.0      2208      3136         86.6  void at::native::elementwise_kernel<(int)128, (int)4, void at::native::gpu_kernel_impl<at::native::…
      0.0         11663104      10242     1138.8     1248.0       928      1664        155.2  void at_cuda_detail::cub::DeviceScanByKeyInitKernel<at_cuda_detail::cub::ReduceByKeyScanTileState<u…
      0.0         10593216       5121     2068.6     2080.0      1984      5056         55.9  void at::native::mbtopk::computeBlockwiseKthCounts<unsigned int>(T1 *, short *, unsigned int, unsig…
      0.0         10012865      10242      977.6      992.0       864      1248         50.8  void at_cuda_detail::cub::DeviceScanInitKernel<at_cuda_detail::cub::ScanTileState<int, (bool)1>>(T1…
      0.0          9535744      10242      931.0      960.0       736      1472        121.2  void at::native::vectorized_elementwise_kernel<(int)4, at::native::FillFunctor<int>, at::detail::Ar…
      0.0          6553888       5121     1279.8     1280.0      1152      1728         46.7  void at::native::unrolled_elementwise_kernel<at::native::direct_copy_kernel_cuda(at::TensorIterator…
      0.0          6363520       5121     1242.6     1248.0      1184      1600         22.3  void at::native::vectorized_elementwise_kernel<(int)4, at::native::CUDAFunctorOnSelf_add<long>, at:…
      0.0          5494624       5122     1072.7     1088.0      1024      2432         25.5  void <unnamed>::elementwise_kernel_with_index<int, at::native::arange_cuda_out(const c10::Scalar &,…
      0.0          5014944       5121      979.3      992.0       896      1472         47.2  void at::native::mbtopk::fill<unsigned int, unsigned int>(T1 *, T1, T2)
      0.0          3157216        241    13100.5     1472.0      1056     37312      16864.6  void at::native::vectorized_elementwise_kernel<(int)4, at::native::FillFunctor<c10::BFloat16>, at::…
      0.0          1016576         42    24204.2    15696.0     15488    368864      54480.7  void at::native::<unnamed>::indexSelectLargeIndex<c10::BFloat16, long, unsigned int, (int)2, (int)2…
      0.0           773632          2   386816.0   386816.0    385152    388480       2353.3  void at_cuda_detail::cub::DeviceSegmentedRadixSortKernel<at_cuda_detail::cub::DeviceRadixSortPolicy…
      0.0           650080          1   650080.0   650080.0    650080    650080          0.0  void at::native::_scatter_gather_elementwise_kernel<(int)128, (int)4, void at::native::_cuda_scatte…
      0.0           327360          1   327360.0   327360.0    327360    327360          0.0  void at_cuda_detail::cub::DeviceSegmentedRadixSortKernel<at_cuda_detail::cub::DeviceRadixSortPolicy…
      0.0           263872          1   263872.0   263872.0    263872    263872          0.0  void at::native::tensor_kernel_scan_innermost_dim<c10::BFloat16, std::plus<c10::BFloat16>>(T1 *, co…
      0.0           223552          2   111776.0   111776.0      3520    220032     153097.1  void at::native::_scatter_gather_elementwise_kernel<(int)128, (int)4, void at::native::_cuda_scatte…
      0.0           181760          2    90880.0    90880.0     90528     91232        497.8  void at::native::elementwise_kernel<(int)128, (int)4, void at::native::gpu_kernel_impl_nocast<at::n…
      0.0           117152          1   117152.0   117152.0    117152    117152          0.0  at::native::<unnamed>::fill_reverse_indices_kernel(long *, int, at::cuda::detail::IntDivider<unsign…
      0.0            94080          1    94080.0    94080.0     94080     94080          0.0  void at::native::<unnamed>::cunn_SoftMaxForward<(int)8, c10::BFloat16, float, c10::BFloat16, at::na…
      0.0            90432          2    45216.0    45216.0     45184     45248         45.3  void at::native::vectorized_elementwise_kernel<(int)4, at::native::<unnamed>::masked_fill_kernel(at…
      0.0             4320          1     4320.0     4320.0      4320      4320          0.0  void at::native::unrolled_elementwise_kernel<at::native::direct_copy_kernel_cuda(at::TensorIterator…
      0.0             4192          1     4192.0     4192.0      4192      4192          0.0  void at::native::<unnamed>::CatArrayBatchedCopy_aligned16_contig<at::native::<unnamed>::OpaqueType<…
      0.0             3488          1     3488.0     3488.0      3488      3488          0.0  void at::native::elementwise_kernel<(int)128, (int)2, void at::native::gpu_kernel_impl_nocast<at::n…
      0.0             2688          1     2688.0     2688.0      2688      2688          0.0  void at::native::elementwise_kernel<(int)128, (int)4, void at::native::gpu_kernel_impl<at::native::…
      0.0             2496          2     1248.0     1248.0      1152      1344        135.8  void <unnamed>::elementwise_kernel_with_index<int, at::native::arange_cuda_out(const c10::Scalar &,…
      0.0             2400          1     2400.0     2400.0      2400      2400          0.0  void at::native::vectorized_elementwise_kernel<(int)4, at::native::sin_kernel_cuda(at::TensorIterat…
      0.0             2368          1     2368.0     2368.0      2368      2368          0.0  void at::native::vectorized_elementwise_kernel<(int)4, at::native::cos_kernel_cuda(at::TensorIterat…
      0.0             1952          1     1952.0     1952.0      1952      1952          0.0  void at::native::vectorized_elementwise_kernel<(int)4, at::native::BUnaryFunctor<float, float, floa…
      0.0             1728          1     1728.0     1728.0      1728      1728          0.0  void at::native::vectorized_elementwise_kernel<(int)4, at::native::CUDAFunctorOnOther_add<c10::BFlo…
      0.0             1696          1     1696.0     1696.0      1696      1696          0.0  void at::native::elementwise_kernel<(int)128, (int)4, void at::native::gpu_kernel_impl_nocast<at::n…
      0.0             1440          1     1440.0     1440.0      1440      1440          0.0  void at::native::vectorized_elementwise_kernel<(int)4, at::native::reciprocal_kernel_cuda(at::Tenso…
      0.0             1280          1     1280.0     1280.0      1280      1280          0.0  void at::native::vectorized_elementwise_kernel<(int)4, at::native::AUnaryFunctor<float, float, floa…
      0.0             1280          1     1280.0     1280.0      1280      1280          0.0  void at::native::vectorized_elementwise_kernel<(int)4, at::native::CUDAFunctorOnOther_add<long>, at…
      0.0             1088          1     1088.0     1088.0      1088      1088          0.0  void at::native::vectorized_elementwise_kernel<(int)4, at::native::FillFunctor<double>, at::detail:…

[6/7] Executing 'cuda_gpu_mem_time_sum' stats report

 Time (%)  Total Time (ns)  Count   Avg (ns)   Med (ns)  Min (ns)   Max (ns)   StdDev (ns)            Operation
 --------  ---------------  ------  ---------  --------  --------  ----------  -----------  ------------------------------
     99.9    1065985331278  783962  1359741.1  449824.0       704  7223016612   24748031.0  [CUDA memcpy Host-to-Device]
      0.1        663054081  836803      792.4     768.0       704        2496         75.3  [CUDA memset]
      0.1        586290338  409681     1431.1    1376.0      1344       94592       1303.4  [CUDA memcpy Device-to-Device]
      0.0         50882912   25605     1987.2    1440.0      1312        5024        768.9  [CUDA memcpy Device-to-Host]

[7/7] Executing 'cuda_gpu_mem_size_sum' stats report

  Total (MB)    Count   Avg (MB)  Med (MB)  Min (MB)  Max (MB)  StdDev (MB)            Operation
 -------------  ------  --------  --------  --------  --------  -----------  ------------------------------
 294088757.097  783962   375.131   167.772     0.000  2101.346      335.997  [CUDA memcpy Host-to-Device]
     77492.519  409681     0.189     0.131     0.131   134.218        1.910  [CUDA memcpy Device-to-Device]
        81.603  836803     0.000     0.000     0.000     0.008        0.000  [CUDA memset]
         1.319   25605     0.000     0.000     0.000     0.002        0.000  [CUDA memcpy Device-to-Host]

Going to mess with the Nsight UI and see if I can analyze the nsys-rep a little better. Will also be testing #6496 today and will run the same benchmark profiling.

@chenqianfzh
Contributor Author

I wonder whether you have profiling tools to separate the two latencies: tensor movement and the weight computation (without the CUDA graph optimization)

Haven't done something like that before, but will do some research. If anyone has any advice or tools for profiling vLLM and PyTorch memory usage, please reply and I'll get into it.

Have been thinking about this statement:

The total latency can only be cut by 0.9s if prefetching is in place

If the prefetching is done as a sliding window over the sequential layers, with the window size set by cpu_offload_trigger_percent, then you'd likely see much better gains than 0.9s. For a simple example, on a 100-layer model with 50% offloaded to system memory, when layer 1 completes you'd offload layer 1 and import layer 51. The matmul would still catch up to the transfers fairly quickly, but on systems with fast CPU<>GPU links it would be increasingly efficient.

I'm brand new to the vLLM architecture, so I'm not the person to implement this, but I'm happy to perform testing or give access to hardware (I've got A100s, both with and without NVLink, as well as a few GH200s lying around to play with).

Thank you so much for your test and suggestion. Let me think through your scheme.

@chenqianfzh
Contributor Author

Closing this, as CPU offload has been implemented and merged in #6496.

@chenqianfzh deleted the qian/conditional-cpu-offload branch on August 30, 2024 at 00:53