From 40f0d62d05521e2912c7371f4524d3f84d747224 Mon Sep 17 00:00:00 2001 From: Siyang Shao Date: Fri, 20 Sep 2024 17:03:17 +0800 Subject: [PATCH] Dev/llama3 (#7) * llama support * flash_attention * sharded * expend * fix: remove redunctant info * change main * llama and opt model supported --------- Co-authored-by: Shao Siyang FYP PDCL Co-authored-by: lairuiqi Co-authored-by: LaiRuiqi <58351056+lrq619@users.noreply.github.com> --- main.py | 8 +- vanilla.py | 58 ++++ vllm/attention/backends/flash_attn_liquid.py | 4 +- vllm/engine/llm_engine.py | 8 +- vllm/liquid/model_executor/layers/linear.py | 211 +++++++++++++- .../layers/vocab_parallel_embedding.py | 6 +- vllm/liquid/sharded_parameter.py | 201 +++++++++---- vllm/liquid/utils.py | 2 +- vllm/model_executor/models/llama_liquid.py | 274 ++++++++++++++++-- vllm/worker/model_runner.py | 5 +- vllm/worker/worker.py | 2 +- 11 files changed, 674 insertions(+), 105 deletions(-) create mode 100644 vanilla.py diff --git a/main.py b/main.py index 34b11e64e6679..12f53598a4659 100644 --- a/main.py +++ b/main.py @@ -16,11 +16,12 @@ def main(): enforce_eager=True, # load_format="auto", # tensor_parallel_size=2, + # liquid_gpu_range = [0,1,2,3], liquid_gpu_range = [0,1,2,3], liquid_gpu_space = 32, liquid_driver_gpu_id = 0, liquid_total_num_shards = 4, - # gpu_memory_utilization=0.8, + ) sampling_params = SamplingParams(temperature=0, min_tokens=128, max_tokens=128) request_num = 1 @@ -37,8 +38,7 @@ def main(): llm.do_liquid(liquid_request) liquid_request = LiquidRequest(LiquidType.LIQUID_2_1) llm.do_liquid(liquid_request) - # liquid_request = LiquidRequest(LiquidType.LIQUID_1_2) - # llm.do_liquid(liquid_request) + output = llm.generate(inputs, sampling_params=sampling_params) @@ -53,4 +53,4 @@ def main(): main() # torch.cuda.memory._dump_snapshot(f"./torch_mem_dump.pickle") # torch.cuda.memory._record_memory_history(enabled=None) - # print(f"dumped finished!") \ No newline at end of file + # print(f"dumped finished!") diff --git a/vanilla.py b/vanilla.py new file mode 100644 index 0000000000000..eb7b3fe1e6106 --- /dev/null +++ b/vanilla.py @@ -0,0 +1,58 @@ + +from vllm import LLM, SamplingParams +from vllm.liquid.request import LiquidRequest, LiquidType +# from vllm import EngineArgs, LLMEngine +import asyncio +import torch + +import os + +model = "meta-llama/Meta-Llama-3-8B" +# model = "facebook/opt-6.7b" +# model_path = os.path.join("./models", model) + +def main(): + llm = LLM( + model, + enforce_eager=True, + # load_format="auto", + tensor_parallel_size=2, + # liquid_gpu_range = [0,1,2,3], + # liquid_gpu_space = 32, + # liquid_driver_gpu_id = 0, + # liquid_total_num_shards = 4, + gpu_memory_utilization=0.8, + ) + sampling_params = SamplingParams(temperature=0, min_tokens=128, max_tokens=128) + request_num = 1 + word = "what is LLM?" 
+ prompt = word + inputs = [prompt for _ in range(request_num)] + +# for i in range(1): +# print(f"i: {i}") +# liquid_request = LiquidRequest(LiquidType.LIQUID_1_2) +# llm.do_liquid(liquid_request) +# # liquid_request = LiquidRequest(LiquidType.LIQUID_2_4) +# # llm.do_liquid(liquid_request) +# # liquid_request = LiquidRequest(LiquidType.LIQUID_4_2) +# # llm.do_liquid(liquid_request) +# liquid_request = LiquidRequest(LiquidType.LIQUID_2_1) +# llm.do_liquid(liquid_request) + +# print("liquid done") + + + output = llm.generate(inputs, sampling_params=sampling_params) + print(f"output: {output[0].outputs[0].text}") + + + + + +if __name__ == '__main__': + # torch.cuda.memory._record_memory_history(context="all", stacks="all") + main() + # torch.cuda.memory._dump_snapshot(f"./torch_mem_dump.pickle") + # torch.cuda.memory._record_memory_history(enabled=None) + # print(f"dumped finished!") diff --git a/vllm/attention/backends/flash_attn_liquid.py b/vllm/attention/backends/flash_attn_liquid.py index cf994f2eb4edc..f41c12a785bd0 100644 --- a/vllm/attention/backends/flash_attn_liquid.py +++ b/vllm/attention/backends/flash_attn_liquid.py @@ -265,7 +265,7 @@ def __init__( def delete_shard(self, shard_id: int): assert shard_id in self.shard_ids - self.num_heads -= self.num_kv_heads_per_shard + self.num_heads -= self.num_heads_per_shard self.num_kv_heads -= self.num_kv_heads_per_shard index = self.shard_ids.index(shard_id) @@ -273,7 +273,7 @@ def delete_shard(self, shard_id: int): def append_shard(self, shard_id: int): assert shard_id not in self.shard_ids - self.num_heads += self.num_kv_heads_per_shard + self.num_heads += self.num_heads_per_shard self.num_kv_heads += self.num_kv_heads_per_shard self.shard_ids.append(shard_id) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 1a6097c0cb79f..e5b593da78cc7 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -221,7 +221,8 @@ def __init__( self.liquid_config = liquid_config self.liquid_request_queue: Queue[LiquidRequest] = Queue() self.execution_lock: threading.Lock = threading.Lock() - self.auto_scaler = AutoScaler(liquid_config=liquid_config) + if liquid_config is not None: + self.auto_scaler = AutoScaler(liquid_config=liquid_config) self.request_output_queue: Queue[RequestOutput] = Queue() if not self.model_config.skip_tokenizer_init: @@ -832,8 +833,9 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: """ # self.model_executor.delete_kv_cache() cache_usage = self.get_latest_metrics().gpu_cache_usage - # liquid_request = None - liquid_request = self.auto_scaler.step(cache_usage) + liquid_request = None + if self.liquid_config is not None: + liquid_request = self.auto_scaler.step(cache_usage) if liquid_request is not None: self.liquid_request_queue.put(liquid_request) diff --git a/vllm/liquid/model_executor/layers/linear.py b/vllm/liquid/model_executor/layers/linear.py index ed53783dc5625..98aeb24d10dce 100644 --- a/vllm/liquid/model_executor/layers/linear.py +++ b/vllm/liquid/model_executor/layers/linear.py @@ -4,7 +4,7 @@ import torch import torch.nn.functional as F from torch.nn.parameter import Parameter -from vllm.liquid.sharded_parameter import ShardedParameter, QKVShardedParameter +from vllm.liquid.sharded_parameter import ShardedParameter, QKVShardedParameter,GateUpShardedParameter from vllm.distributed import (divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, @@ -94,15 +94,28 @@ def create_weights(self, layer: torch.nn.Module, shard_dim: int = -1, 
param_class = ShardedParameter, **extra_weight_attrs): - weight = param_class( - data=torch.empty(sum(output_partition_sizes), - input_size_per_partition, - dtype=params_dtype), - num_shards=len(shard_ids), - shard_dim=shard_dim, - shard_ids=shard_ids, - requires_grad=False, - ) + if param_class == QKVShardedParameter: + weight = QKVShardedParameter( + data=torch.empty(sum(output_partition_sizes), + input_size_per_partition, + dtype=params_dtype), + num_shards=len(shard_ids), + shard_dim=shard_dim, + shard_ids=shard_ids, + requires_grad=False, + num_heads_ratio=extra_weight_attrs['num_heads_ratio'], + num_kv_heads_ratio=extra_weight_attrs['num_kv_heads_ratio'], + ) + else: + weight = param_class( + data=torch.empty(sum(output_partition_sizes), + input_size_per_partition, + dtype=params_dtype), + num_shards=len(shard_ids), + shard_dim=shard_dim, + shard_ids=shard_ids, + requires_grad=False, + ) set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) layer.register_parameter("weight", weight) set_weight_attrs(weight, extra_weight_attrs) @@ -276,6 +289,8 @@ def __init__(self, shard_ids: List[int] = [0], total_num_shards: int = 1, param_class = ShardedParameter, + num_heads_ratio: int=1, + num_kv_heads_ratio: int=1, ): super().__init__(input_size, output_size, skip_bias_add, params_dtype, quant_config) @@ -310,6 +325,8 @@ def __init__(self, shard_ids=shard_ids, shard_dim=shard_dim, param_class=param_class, + num_heads_ratio=num_heads_ratio, + num_kv_heads_ratio=num_kv_heads_ratio, ) if bias: self.bias = param_class( @@ -446,6 +463,8 @@ def __init__(self, shard_ids=shard_ids, total_num_shards=total_num_shards, param_class=QKVShardedParameter, + num_heads_ratio=self.num_heads, + num_kv_heads_ratio=self.num_kv_heads, ) def weight_loader(self, @@ -737,3 +756,175 @@ def extra_repr(self) -> str: s += f", tp_size={self.tp_size}" s += f", reduce_results={self.reduce_results}" return s + + +class MergedColumnParallelLinear(ColumnParallelLinear): + """Packed linear layers with column parallelism. + + Similar to ColumnParallelLinear, but the weight matrix is concatenated + along the output dimension. When the weight matrix is loaded, the + different partitions are sharded separately. + + Args: + input_size: input dimension of the linear layer. + output_sizes: list of output dimensions of the linear layer. + bias: If true, add bias. + gather_output: If true, call all-gather on output and make the output + available to all GPUs, otherwise, every GPU will have + its own output. + skip_bias_add: This was added to enable performance optimizations where + bias can be fused with other element-wise operations. we + skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configure. 
+ """ + + def __init__(self, + input_size: int, + output_sizes: List[int], + bias: bool = True, + gather_output: bool = False, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + shard_ids: List[int] = [0], + total_num_shards: int = 1,): + self.output_sizes = output_sizes + # tp_size = get_tensor_model_parallel_world_size() + # assert all(output_size % tp_size == 0 for output_size in output_sizes) + super().__init__(input_size=input_size, + output_size=sum(output_sizes), + bias=bias, + gather_output=gather_output, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + quant_config=quant_config, + shard_ids=shard_ids, + total_num_shards=total_num_shards, + param_class=GateUpShardedParameter, + ) + + def weight_loader(self, + param: Parameter, + loaded_weight: torch.Tensor, + loaded_shard_id: Optional[int] = None): + + param_data = param.data + output_dim = getattr(param, "output_dim", None) + # Special case for AQLM codebooks. + is_metadata = getattr(param, "is_metadata", False) + + param_shard_splitter = getattr(param, "shard_splitter", None) + + if output_dim is not None and param_shard_splitter is not None: + raise NotImplementedError( + "We do not currently support output_dim != None and " + "shard_splitter != None for a parameter. Please open an issue." + ) + # If a parameter has defined a shard_splitter to be used for + # the weight, it should be applied before the weight is + # loaded/copied to the parameter. The shard_splitter applies + # logic by using the loaded_shard_id to ensure that the loaded + # param is loaded to the correct location + # within the parameter defined by the linear method. + if loaded_shard_id is None and param_shard_splitter is not None: + raise NotImplementedError( + "We do not currently support loaded_shard_id == None and " + "shard_splitter != None for a parameter. Please open an issue." + ) + + # Special case for Fp8 scales. + fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer", + None) + + if loaded_shard_id is None: + # Loaded weight is already packed. + if output_dim is None: + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + return + current_shard_offset = 0 + shard_offsets = [] + for i, output_size in enumerate(self.output_sizes): + shard_offsets.append((i, current_shard_offset, output_size)) + current_shard_offset += output_size + packed_dim = getattr(param, "packed_dim", None) + for shard_id, shard_offset, shard_size in shard_offsets: + # Special case for Quantization. + # If quantized, we need to adjust the offset and size to account + # for the packing. + if packed_dim == output_dim: + shard_size = shard_size // param.pack_factor + shard_offset = shard_offset // param.pack_factor + # Special case for Marlin. + shard_size, shard_offset = adjust_marlin_shard( + param, shard_size, shard_offset) + + loaded_weight_shard = loaded_weight.narrow( + output_dim, shard_offset, shard_size) + self.weight_loader(param, loaded_weight_shard, shard_id) + return + + assert loaded_shard_id < len(self.output_sizes) + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + if output_dim is not None: + shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size + shard_size = self.output_sizes[loaded_shard_id] // tp_size + # Special case for quantization. + # If quantized, we need to adjust the offset and size to account + # for the packing. 
+ packed_dim = getattr(param, "packed_dim", None) + if packed_dim == output_dim: + shard_size = shard_size // param.pack_factor + shard_offset = shard_offset // param.pack_factor + # Special case for Marlin. + shard_size, shard_offset = adjust_marlin_shard( + param, shard_size, shard_offset) + + use_bitsandbytes = getattr(param, "use_bitsandbytes", False) + if use_bitsandbytes: + shard_size = loaded_weight.shape[output_dim] + shard_offset = loaded_weight.shape[output_dim] * \ + loaded_shard_id + + param_data = param_data.narrow(output_dim, shard_offset, + shard_size) + start_idx = tp_rank * shard_size + loaded_weight = loaded_weight.narrow(output_dim, start_idx, + shard_size) + # Special case for AQLM codebooks. + elif is_metadata: + # metadata indicates fixed size concatenated along dim 0 + shard_size = loaded_weight.shape[0] + shard_offset = loaded_shard_id * shard_size + param_data = param_data.narrow(0, shard_offset, shard_size) + + # If a param_shard_splitter is defined by the LinearMethod, use it. + elif param_shard_splitter is not None: + logical_widths = getattr(param, "logical_widths", None) + param_data, loaded_weight = param_shard_splitter( + param_data, loaded_weight, loaded_shard_id, logical_widths) + + # Special case for Fp8 scales. + elif fp8_scales_shard_indexer is not None: + param_data, loaded_weight = fp8_scales_shard_indexer( + param_data, loaded_weight, loaded_shard_id) + + else: + ignore_warning = getattr(param, "ignore_warning", False) + if not ignore_warning: + logger.warning( + "Loading a weight without `output_dim` attribute in " + "MergedColumnParallelLinear, assume the weight is " + "the same for all partitions.") + + if fp8_scales_shard_indexer is None: + if len(param_data.shape) == 0: + param_data = param_data.reshape(1) + + if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) diff --git a/vllm/liquid/model_executor/layers/vocab_parallel_embedding.py b/vllm/liquid/model_executor/layers/vocab_parallel_embedding.py index a9ccfaf4321b4..0337015b4e60a 100644 --- a/vllm/liquid/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/liquid/model_executor/layers/vocab_parallel_embedding.py @@ -394,9 +394,11 @@ def __init__(self, bias: bool = False, params_dtype: Optional[torch.dtype] = None, org_num_embeddings: Optional[int] = None, - padding_size: int = DEFAULT_VOCAB_PADDING_SIZE): + padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, + shard_ids: List[int] = [0], + total_num_shards: int = 1,): super().__init__(num_embeddings, embedding_dim, params_dtype, - org_num_embeddings, padding_size) + org_num_embeddings, padding_size, shard_ids, total_num_shards) if bias: self.bias = Parameter( torch.empty(self.num_embeddings_per_partition, diff --git a/vllm/liquid/sharded_parameter.py b/vllm/liquid/sharded_parameter.py index 24051faffabb0..423cb5e99b616 100644 --- a/vllm/liquid/sharded_parameter.py +++ b/vllm/liquid/sharded_parameter.py @@ -4,7 +4,7 @@ class ShardedParameter(Parameter): - def __new__(cls, data=None, requires_grad=False, num_shards=1, shard_dim=0, shard_ids=None): + def __new__(cls, data=None, requires_grad=False, num_shards=1, shard_dim=0, shard_ids=None, **kwargs): # Call the __new__ method of the Parameter class instance = super(ShardedParameter, cls).__new__(cls, data, requires_grad) return instance @@ -35,62 +35,74 @@ def __init__(self, pass - def _get_shard(self, tensor: torch.Tensor, shard_id: int) -> torch.Tensor: + def _get_shard(self, 
tensor: torch.Tensor, shard_id: int, shard_size: Optional[int]=None) -> torch.Tensor: + if shard_size is None: + shard_size = self.shard_size index = self.shard_ids.index(shard_id) - start_index = index * self.shard_size + start_index = index * shard_size - shard = tensor.narrow(self.shard_dim, start_index, self.shard_size) + shard = tensor.narrow(self.shard_dim, start_index, shard_size) return shard - def _get_shards(self, tensor: torch.Tensor, start_shard_id: int, end_shard_id: int) -> torch.Tensor: + def _get_shards(self, tensor: torch.Tensor, start_shard_id: int, end_shard_id: int, shard_size: Optional[int]=None) -> torch.Tensor: + if shard_size is None: + shard_size = self.shard_size index = self.shard_ids.index(start_shard_id) - start_index = index*self.shard_size + start_index = index*shard_size - shards = tensor.narrow(self.shard_dim, start_index, self.shard_size*(end_shard_id - start_shard_id)) + shards = tensor.narrow(self.shard_dim, start_index, shard_size*(end_shard_id - start_shard_id)) return shards - def get_shard(self, shard_id: int) -> torch.Tensor: + def get_shard(self, shard_id: int, shard_size: Optional[int] = None) -> torch.Tensor: if shard_id not in self.shard_ids: raise ValueError(f"shard_id: {shard_id} not in self.shard_ids") - shard = self._get_shard(self.data, shard_id) + shard = self._get_shard(self.data, shard_id, shard_size) return shard - def get_shards(self, start_shard_id: int, end_shard_id: int) -> torch.Tensor: - shards = self._get_shards(self.data, start_shard_id, end_shard_id) + def get_shards(self, start_shard_id: int, end_shard_id: int, shard_size: Optional[int] = None) -> torch.Tensor: + shards = self._get_shards(self.data, start_shard_id, end_shard_id, shard_size) return shards - def _delete_shards(self, tensor: torch.Tensor, start_shard_id: int, end_shard_id: int) -> torch.Tensor: + def _delete_shards(self, tensor: torch.Tensor, start_shard_id: int, end_shard_id: int, shard_size: Optional[int]=None) -> torch.Tensor: index = self.shard_ids.index(start_shard_id) - - start_index = index * self.shard_size + if shard_size is None: + shard_size = self.shard_size + start_index = index * shard_size before_shard = tensor.narrow(self.shard_dim, 0, start_index) - after_shard = tensor.narrow(self.shard_dim, start_index + self.shard_size*(end_shard_id - start_shard_id), tensor.size(self.shard_dim) - start_index - self.shard_size*(end_shard_id - start_shard_id)) + after_shard = tensor.narrow(self.shard_dim, start_index + shard_size*(end_shard_id - start_shard_id), tensor.size(self.shard_dim) - start_index - shard_size*(end_shard_id - start_shard_id)) new_data = torch.cat([before_shard, after_shard], dim=self.shard_dim) return new_data - def _delete_shard(self, tensor: torch.Tensor, shard_id: int) -> torch.Tensor: + def _delete_shard(self, tensor: torch.Tensor, shard_id: int, shard_size: Optional[int]=None) -> torch.Tensor: index = self.shard_ids.index(shard_id) - - start_index = index * self.shard_size + if shard_size is None: + shard_size = self.shard_size + start_index = index * shard_size # Create views of the tensor parts before and after the shard before_shard = tensor.narrow(self.shard_dim, 0, start_index) - after_shard = tensor.narrow(self.shard_dim, start_index + self.shard_size, tensor.size(self.shard_dim) - start_index - self.shard_size) + after_shard = tensor.narrow(self.shard_dim, start_index + shard_size, tensor.size(self.shard_dim) - start_index - shard_size) # Concatenate the views to form a new tensor new_data = torch.cat([before_shard, 
after_shard], dim=self.shard_dim) del before_shard, after_shard return new_data - def delete_shards(self, start_shard_id: int, end_shard_id: int) -> None: - new_data = self._delete_shards(self.data, start_shard_id, end_shard_id) + def delete_shards(self, start_shard_id: int, end_shard_id: int, shard_size: Optional[int] = None) -> None: + new_data = self._delete_shards(self.data, start_shard_id, end_shard_id, shard_size) self.data = new_data for shard_id in range(start_shard_id, end_shard_id): index = self.shard_ids.index(shard_id) self.shard_ids.pop(index) + def delete_shard_indexs(self, start_shard_id: int, end_shard_id: int): + for shard_id in range(start_shard_id, end_shard_id): + index = self.shard_ids.index(shard_id) + self.shard_ids.pop(index) + + def delete_shard(self, shard_id: int) -> None: @@ -172,37 +184,70 @@ def __init__(self, shard_dim: int = 0, shard_ids : Optional[List[int]] = None, requires_grad: bool = False, + num_heads_ratio: int = 1, + num_kv_heads_ratio: int = 1, ): super().__init__(data, num_shards, shard_dim, shard_ids) self.requires_grad = requires_grad - self.shard_size = self.shard_size // 3 - assert self.size(shard_dim) % 3 == 0, f"QKV parameter must have a length divisible by 3 along dim: {shard_dim}" + import math + d = math.gcd(num_heads_ratio, num_kv_heads_ratio) + num_heads_ratio = num_heads_ratio // d + num_kv_heads_ratio = num_kv_heads_ratio // d + self._num_heads_ratio = num_heads_ratio + self._num_kv_heads_ratio = num_kv_heads_ratio + # assert self.size(shard_dim) % 3 == 0, f"QKV parameter must have a length divisible by 3 along dim: {shard_dim}" # qkv_shard_size = self.size(shard_dim) // 3 # self.q_data = self.narrow(shard_dim, 0, qkv_shard_size) # self.k_data = self.narrow(shard_dim, qkv_shard_size, qkv_shard_size) # self.v_data = self.narrow(shard_dim, 2*qkv_shard_size, qkv_shard_size) # self.q_data, self.k_data, self.v_data = data.chunk(3, shard_dim) + + def get_qkv_size(self, siz: int) -> int: + q_size = siz // (self._num_heads_ratio + 2 * self._num_kv_heads_ratio) * self._num_heads_ratio + k_size = siz // (self._num_heads_ratio + 2 * self._num_kv_heads_ratio) * self._num_kv_heads_ratio + v_size = siz // (self._num_heads_ratio + 2 * self._num_kv_heads_ratio) * self._num_kv_heads_ratio + return q_size, k_size, v_size + + def customize_chunk(self, data: torch.Tensor) -> torch.Tensor: + shape = list(data.shape) + if self.shard_dim >= len(shape): + raise ValueError(f"shard_dim: {self.shard_dim} is larger than the number of dimensions of the tensor: {len(shape)}") + siz = shape[self.shard_dim] + if siz % (self._num_heads_ratio + 2 * self._num_kv_heads_ratio) != 0: + raise ValueError(f"QKV parameter must have a length divisible by {self._num_heads_ratio + 2 * self._num_kv_heads_ratio} along dim: {self.shard_dim}") + q_size, k_size, v_size = self.get_qkv_size(siz) + try: + q_tensor = torch.narrow(data,self.shard_dim, 0, q_size) + k_tensor = torch.narrow(data,self.shard_dim, q_size, k_size) + v_tensor = torch.narrow(data,self.shard_dim, q_size + k_size, v_size) + except Exception as e: + raise ValueError(f"shape: {data.shape}, dim: {self.shard_dim}, q_size: {q_size}, k_size: {k_size}, v_size: {v_size} ,Error in customizing chunk: {e}") + return q_tensor, k_tensor, v_tensor + def get_shard(self, shard_id: int) -> torch.Tensor: - q_data, k_data, v_data = self.data.chunk(3, self.shard_dim) - q_shard = self._get_shard(q_data, shard_id) - k_shard = self._get_shard(k_data, shard_id) - v_shard = self._get_shard(v_data, shard_id) + q_data, k_data, v_data = 
self.customize_chunk(self.data) + q_shard_size, k_shard_size, v_shard_size = self.get_qkv_size(self.shard_size) + q_shard = self._get_shard(q_data, shard_id, q_shard_size) + k_shard = self._get_shard(k_data, shard_id, k_shard_size) + v_shard = self._get_shard(v_data, shard_id, v_shard_size) return q_shard, k_shard, v_shard def get_shards(self, start_shard_id: int, end_shard_id: int) -> torch.Tensor: - q_data, k_data, v_data = self.data.chunk(3, self.shard_dim) - q_shards = self._get_shards(q_data, start_shard_id, end_shard_id) - k_shards = self._get_shards(k_data, start_shard_id, end_shard_id) - v_shards = self._get_shards(v_data, start_shard_id, end_shard_id) + q_data, k_data, v_data = self.customize_chunk(self.data) + q_shard_size, k_shard_size, v_shard_size = self.get_qkv_size(self.shard_size) + q_shards = self._get_shards(q_data, start_shard_id, end_shard_id, q_shard_size) + k_shards = self._get_shards(k_data, start_shard_id, end_shard_id, k_shard_size) + v_shards = self._get_shards(v_data, start_shard_id, end_shard_id, v_shard_size) return q_shards, k_shards, v_shards def delete_shard(self, shard_id: int) -> None: - q_data, k_data, v_data = self.data.chunk(3, self.shard_dim) - q_data = self._delete_shard(q_data, shard_id) - k_data = self._delete_shard(k_data, shard_id) - v_data = self._delete_shard(v_data, shard_id) + q_data, k_data, v_data = self.customize_chunk(self.data) + q_shard_size, k_shard_size, v_shard_size = self.get_qkv_size(self.shard_size) + q_data = self._delete_shard(q_data, shard_id, q_shard_size) + k_data = self._delete_shard(k_data, shard_id, k_shard_size) + v_data = self._delete_shard(v_data, shard_id, v_shard_size) new_data = torch.cat([q_data, k_data, v_data], dim=self.shard_dim) self.data = new_data @@ -213,10 +258,11 @@ def delete_shard(self, shard_id: int) -> None: self.shard_ids.pop(index) def delete_shards(self, start_shard_id: int, end_shard_id: int): - q_data, k_data, v_data = self.data.chunk(3, self.shard_dim) - q_data = self._delete_shards(q_data, start_shard_id, end_shard_id) - k_data = self._delete_shards(k_data, start_shard_id, end_shard_id) - v_data = self._delete_shards(v_data, start_shard_id, end_shard_id) + q_data, k_data, v_data = self.customize_chunk(self.data) + q_shard_size, k_shard_size, v_shard_size = self.get_qkv_size(self.shard_size) + q_data = self._delete_shards(q_data, start_shard_id, end_shard_id, q_shard_size) + k_data = self._delete_shards(k_data, start_shard_id, end_shard_id, k_shard_size) + v_data = self._delete_shards(v_data, start_shard_id, end_shard_id, v_shard_size) new_data = torch.cat([q_data, k_data, v_data], dim=self.shard_dim) self.data = new_data # del q_data, k_data, v_data @@ -233,7 +279,7 @@ def append_shard(self, shard_id: int, q_shard: torch.Tensor, k_shard: torch.Tens raise ValueError(f"shard_id: {shard_id} is already in self.shard_ids") - q_data, k_data, v_data = self.data.chunk(3, self.shard_dim) + q_data, k_data, v_data = self.customize_chunk(self.data) q_data = self._append_shard(q_data, q_shard) k_data = self._append_shard(k_data, k_shard) v_data = self._append_shard(v_data, v_shard) @@ -244,22 +290,6 @@ def append_shard(self, shard_id: int, q_shard: torch.Tensor, k_shard: torch.Tens self.shard_ids.append(shard_id) - # def append_shards(self, start_shard_id: int, end_shard_id: int,q_shard: torch.Tensor, k_shard: torch.Tensor, v_shard: torch.Tensor) -> None: - - - # q_data, k_data, v_data = self.data.chunk(3, self.shard_dim) - # q_data = self._append_shard(q_data, q_shard) - # k_data = self._append_shard(k_data, 
k_shard) - # v_data = self._append_shard(v_data, v_shard) - - # self.data = torch.cat([q_data, k_data, v_data], dim=self.shard_dim) - # del q_data, k_data, v_data - # # del self.q_data, self.k_data, self.v_data - # # self.q_data, self.k_data, self.v_data = self.data.chunk(3) - - # for shard_id in range(start_shard_id, end_shard_id): - # self.shard_ids.append(shard_id) - def append_shards(self, start_shard_id: int, end_shard_id: int,q_shard: torch.Tensor, k_shard: torch.Tensor, v_shard: torch.Tensor) -> None: if self.shard_ids == []: self.data = torch.cat([q_shard, k_shard, v_shard]) @@ -280,12 +310,67 @@ def extend_and_load_shard(self, q_shard: torch.Tensor, k_shard: torch.Tensor, v_ dtype=self.data.dtype, device=self.data.device, ) - q_data, k_data, v_data =self.data.chunk(chunks=3, dim=self.shard_dim) - new_q_data, new_k_data, new_v_data = new_data.chunk(chunks=3, dim=self.shard_dim) + q_data, k_data, v_data = self.customize_chunk(self.data) + new_q_data, new_k_data, new_v_data = self.customize_chunk(new_data) self._in_place_cat(new_q_data, q_data, q_shard) self._in_place_cat(new_k_data, k_data, k_shard) self._in_place_cat(new_v_data, v_data, v_shard) # self.q_data, self.k_data, self.v_data = self.data.chunk(3, dim=self.shard_dim) + self.data = new_data + +# TODO: current is llama3 only +class GateUpShardedParameter(ShardedParameter): + def __init__(self, + data: torch.Tensor, + num_shards: int = 1, + shard_dim: int = 0, + shard_ids: Optional[List[int]] = None, + requires_grad: bool = False, + ): + super().__init__(data, num_shards, shard_dim, shard_ids) + self.requires_grad = requires_grad + self.shard_size = self.shard_size // 2 + assert self.size(shard_dim) % 2 == 0, f"merged column parameter must have a length divisible by 2 along dim: {shard_dim}" + + def get_shards(self, start_shard_id: int, end_shard_id: int) -> torch.Tensor: + gate_data, up_data = self.data.chunk(2, self.shard_dim) + gate_shards = self._get_shards(gate_data, start_shard_id, end_shard_id) + up_shards = self._get_shards(up_data, start_shard_id, end_shard_id) + return gate_shards, up_shards + + def delete_shards(self, start_shard_id: int, end_shard_id: int) -> None: + gate_data, up_data = self.data.chunk(2, self.shard_dim) + gate_data = self._delete_shards(gate_data, start_shard_id, end_shard_id) + up_data = self._delete_shards(up_data, start_shard_id, end_shard_id) + new_data = torch.cat([gate_data, up_data], dim=self.shard_dim) + self.data = new_data + for shard_id in range(start_shard_id, end_shard_id): + index = self.shard_ids.index(shard_id) + self.shard_ids.pop(index) + + def append_shards(self, start_shard_id: int, end_shard_id: int, gate_shard: torch.Tensor, up_shard: torch.Tensor) -> None: + if self.shard_ids == []: + self.data = torch.cat([gate_shard, up_shard]) + else: + self.extend_and_load_shard(gate_shard, up_shard) + for shard_id in range(start_shard_id, end_shard_id): + self.shard_ids.append(shard_id) + + def extend_and_load_shard(self, gate_shard: torch.Tensor, up_shard: torch.Tensor) -> None: + shape = list(self.data.shape) + shape[self.shard_dim] = shape[self.shard_dim] * 2 + new_shape = torch.Size(shape) + new_data = torch.empty( + size=new_shape, + dtype=self.data.dtype, + device=self.data.device, + ) + gate_data, up_data = self.data.chunk(chunks=2, dim=self.shard_dim) + new_gate_data, new_up_data = new_data.chunk(chunks=2, dim=self.shard_dim) + + self._in_place_cat(new_gate_data, gate_data, gate_shard) + self._in_place_cat(new_up_data, up_data, up_shard) + self.data = new_data \ No newline at 
end of file diff --git a/vllm/liquid/utils.py b/vllm/liquid/utils.py index 08b000ad2d95c..8f84bdd51e86c 100644 --- a/vllm/liquid/utils.py +++ b/vllm/liquid/utils.py @@ -133,7 +133,7 @@ def get_tensor_num_bytes(tensor: torch.Tensor) -> int: total_memory_bytes = (num_elements * bits) // 8 return total_memory_bytes -DEBUG_MODE = False +DEBUG_MODE = True def get_cuda_mem_info(device: int=0) -> str: # torch.cuda.set_device(f"cuda:{device}") diff --git a/vllm/model_executor/models/llama_liquid.py b/vllm/model_executor/models/llama_liquid.py index d83ee9a201c0b..7bd62383eca5c 100644 --- a/vllm/model_executor/models/llama_liquid.py +++ b/vllm/model_executor/models/llama_liquid.py @@ -33,21 +33,32 @@ get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, +# from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, +# QKVParallelLinear, +# RowParallelLinear) +from vllm.liquid.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) + from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler +#TODO from vllm.model_executor.layers.vocab_parallel_embedding import ( - DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) + DEFAULT_VOCAB_PADDING_SIZE) +from vllm.liquid.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, ParallelLMHead) + from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, kv_cache_scales_loader) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplerOutput from vllm.utils import is_hip, print_warning_once +from vllm.config import LiquidConfig +from vllm.liquid.sharded_parameter import ShardedParameter, QKVShardedParameter, GateUpShardedParameter +from vllm.liquid.utils import get_cuda_mem_info class LlamaMLP(nn.Module): @@ -59,17 +70,23 @@ def __init__( hidden_act: str, quant_config: Optional[QuantizationConfig] = None, bias: bool = False, + shard_ids: List[int] = [0], + total_num_shards: int = 1, ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( input_size=hidden_size, output_sizes=[intermediate_size] * 2, bias=bias, - quant_config=quant_config) + quant_config=quant_config, + shard_ids=shard_ids, + total_num_shards=total_num_shards,) self.down_proj = RowParallelLinear(input_size=intermediate_size, output_size=hidden_size, bias=bias, - quant_config=quant_config) + quant_config=quant_config, + shard_ids= shard_ids, + total_num_shards=total_num_shards) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. 
" "Only silu is supported for now.") @@ -95,30 +112,21 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, bias: bool = False, cache_config: Optional[CacheConfig] = None, + shard_ids: List[int] = [0], + total_num_shards: int = 1, ) -> None: super().__init__() + self._shard_ids = shard_ids + self.total_num_shards = total_num_shards + self.hidden_size = hidden_size - tp_size = get_tensor_model_parallel_world_size() self.total_num_heads = num_heads - assert self.total_num_heads % tp_size == 0 - self.num_heads = self.total_num_heads // tp_size self.total_num_kv_heads = num_kv_heads - if self.total_num_kv_heads >= tp_size: - # Number of KV heads is greater than TP size, so we partition - # the KV heads across multiple tensor parallel GPUs. - assert self.total_num_kv_heads % tp_size == 0 - else: - # Number of KV heads is less than TP size, so we replicate - # the KV heads across multiple tensor parallel GPUs. - assert tp_size % self.total_num_kv_heads == 0 - self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) self.head_dim = hidden_size // self.total_num_heads - self.q_size = self.num_heads * self.head_dim - self.kv_size = self.num_kv_heads * self.head_dim - self.scaling = self.head_dim**-0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings - + self.scaling = self.head_dim**-0.5 + self.update_param() self.qkv_proj = QKVParallelLinear( hidden_size=hidden_size, head_size=self.head_dim, @@ -126,12 +134,16 @@ def __init__( total_num_kv_heads=self.total_num_kv_heads, bias=bias, quant_config=quant_config, + shard_ids=shard_ids, + total_num_shards=total_num_shards ) self.o_proj = RowParallelLinear( input_size=self.total_num_heads * self.head_dim, output_size=hidden_size, bias=bias, quant_config=quant_config, + shard_ids=shard_ids, + total_num_shards=total_num_shards ) self.rotary_emb = get_rope( @@ -146,7 +158,16 @@ def __init__( self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + shard_ids=shard_ids, + total_num_shards=total_num_shards) + + def update_param(self): + num_shards = len(self._shard_ids) + self.num_heads = self.total_num_heads * num_shards // self.total_num_shards + self.num_kv_heads = max(1, self.total_num_kv_heads * num_shards // self.total_num_shards) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim def forward( self, @@ -170,6 +191,8 @@ def __init__( config: LlamaConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + shard_ids: List[int] = [0], + total_num_shards: int = 1, ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -196,6 +219,8 @@ def __init__( quant_config=quant_config, bias=attention_bias, cache_config=cache_config, + shard_ids=shard_ids, + total_num_shards=total_num_shards, ) self.mlp = LlamaMLP( hidden_size=self.hidden_size, @@ -203,6 +228,8 @@ def __init__( hidden_act=config.hidden_act, quant_config=quant_config, bias=getattr(config, "mlp_bias", False), + shard_ids=shard_ids, + total_num_shards=total_num_shards, ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -246,6 +273,8 @@ def __init__( cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, + shard_ids: List[int] = [0], + total_num_shards: int = 1, ) -> None: super().__init__() self.config = config @@ -258,11 +287,15 @@ def __init__( 
self.vocab_size, config.hidden_size, org_num_embeddings=config.vocab_size, + shard_ids=shard_ids, + total_num_shards=total_num_shards, ) self.layers = nn.ModuleList([ LlamaDecoderLayer(config=config, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + shard_ids=shard_ids, + total_num_shards=total_num_shards) for idx in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -334,13 +367,21 @@ def __init__( cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, + liquid_config: Optional[LiquidConfig] = None, + shard_ids: List[int] = [0], ) -> None: super().__init__() + total_num_shards = 1 if liquid_config is None else liquid_config.liquid_total_num_shards self.config = config + self.shard_ids = shard_ids + self.liquid_config = liquid_config self.model = LlamaModel(config, cache_config, quant_config, - lora_config=lora_config) + lora_config=lora_config, + shard_ids=shard_ids, + total_num_shards=total_num_shards) + self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size @@ -352,7 +393,10 @@ def __init__( # We need bigger padding if using lora for kernel # compatibility if not lora_config else lora_config.lora_vocab_padding_size, + shard_ids=shard_ids, + total_num_shards=total_num_shards, ) + # seems not typing, let's just ignore it if config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight @@ -360,6 +404,10 @@ def __init__( self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size, logit_scale) self.sampler = Sampler() + if self.liquid_config is not None: + self.total_num_shards = self.liquid_config.liquid_total_num_shards + + def forward( self, @@ -385,6 +433,11 @@ def sample( ) -> Optional[SamplerOutput]: next_tokens = self.sampler(logits, sampling_metadata) return next_tokens + + def check_weights(self): + for name, weight in self.model.named_parameters(): + print(f"Name: {name}, Weight: {weight}") + print(f"\n\n") def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ @@ -441,6 +494,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # factors (or else raise an exception). 
Thus, handled exceptions should # make sure to leave KV cache scale factors in a known good (dummy) state def load_kv_cache_scales(self, quantization_param_path: str) -> None: + raise NotImplementedError tp_size = get_tensor_model_parallel_world_size() tp_rank = get_tensor_model_parallel_rank() for layer_idx, scaling_factor in kv_cache_scales_loader( @@ -460,3 +514,177 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None: else: raise RuntimeError("Self attention has no KV cache scaling " "factor attribute!") + + def named_sharded_parameters(self): + for name, param in self.named_parameters(): + if hasattr(param, "shard_ids"): + yield name, param + + def sorted_named_parameters(self, descending: bool = True): + # sort the parameters first to avoid a memory fragmentation + # Get the named parameters of the model + named_params = list(self.named_parameters()) + + # Sort the named parameters based on the number of elements (numel()) in descending order + sorted_named_params = sorted(named_params, key=lambda x: x[1].numel(), reverse=descending) + + return sorted_named_params + + + + def get_shards_weights(self, shard_ids: List[int], only_sharded: bool = True) -> Dict[str, torch.Tensor]: + results = {} + if len(shard_ids) == 1: + start_shard_id = shard_ids[0] + end_shard_id = start_shard_id+1 + else: + start_shard_id = shard_ids[0] + end_shard_id = shard_ids[-1] + 1 + + for name, param in self.sorted_named_parameters(): + if isinstance(param, QKVShardedParameter): + q_shard, k_shard, v_shard = param.get_shards(start_shard_id, end_shard_id) + results[f"{name}_q"] = q_shard + results[f"{name}_k"] = k_shard + results[f"{name}_v"] = v_shard + elif isinstance(param, GateUpShardedParameter): + gate_shard, up_shard = param.get_shards(start_shard_id, end_shard_id) + results[f"{name}_gate"] = gate_shard + results[f"{name}_up"] = up_shard + elif isinstance(param, ShardedParameter): + results[name] = param.get_shards(start_shard_id, end_shard_id) + else: + if not only_sharded: + results[name] = param + # sort the results to reduce memory fragmentation + return results + + def delete_shards(self, shard_ids: List[int]) -> None: + + if len(shard_ids) == 1: + start_shard_id = shard_ids[0] + end_shard_id = start_shard_id+1 + else: + start_shard_id = shard_ids[0] + end_shard_id = shard_ids[-1] + 1 + + shard_dim = self.lm_head.weight.shard_dim + lm_head_first_half, lm_head_last_half = self.lm_head.weight.chunk(2, shard_dim) + embed_token_first_half, _ = self.model.embed_tokens.weight.chunk(2, shard_dim) + lm_head_last_half.copy_(embed_token_first_half) + del embed_token_first_half + self.lm_head.weight.data = lm_head_first_half + self.model.embed_tokens.weight.data = lm_head_last_half + + self.lm_head.weight.delete_shard_indexs(start_shard_id, end_shard_id) + self.model.embed_tokens.weight.delete_shard_indexs(start_shard_id, end_shard_id) + + # print(f"Before deleting shards, {get_cuda_mem_info()}") + # torch.cuda.memory._record_memory_history(max_entries=100000, context="all") + for name, param in self.sorted_named_parameters(True): + if hasattr(param, "num_shards"): + if name in ['lm_head.weight', 'model.embed_tokens.weight']: + continue + param.delete_shards(start_shard_id, end_shard_id) + # torch.cuda.empty_cache() + # torch.cuda.empty_cache() + # print(f"After deleting shards, {get_cuda_mem_info()}") + # torch.cuda.memory._dump_snapshot(f"./torch_mem_dump.pickle") + # torch.cuda.memory._record_memory_history(enabled=None) + + for layer in self.model.layers: + for shard_id in 
range(start_shard_id, end_shard_id): + layer.self_attn.attn.delete_shard(shard_id) + layer.self_attn.update_param() + + for shard_id in range(start_shard_id, end_shard_id): + index = self.shard_ids.index(shard_id) + self.shard_ids.pop(index) + self.model.embed_tokens.update_sharded_indices(shard_ids=self.shard_ids, total_num_shards=self.total_num_shards) + for layer in self.model.layers: + layer.self_attn.update_param() + + + def load_shards_weights(self, shard_ids: List[int], shards_weights: Dict[str, torch.Tensor]): + if len(shard_ids) == 1: + start_shard_id = shard_ids[0] + end_shard_id = start_shard_id+1 + else: + start_shard_id = shard_ids[0] + end_shard_id = shard_ids[-1] + 1 + shard_id = shard_ids[0] + assert shard_id in self.shard_ids, f"{shard_id} not in the model" + for name, param in self.sorted_named_parameters(): + if isinstance(param, QKVShardedParameter): + q_shard = shards_weights[f"{name}_q"] + k_shard = shards_weights[f"{name}_k"] + v_shard = shards_weights[f"{name}_v"] + # print(param.requires_grad) + # q_data, k_data, v_data = param.chunk(3, dim=param.shard_dim) + q_data, k_data, v_data = param.customize_chunk(param.data) + q_data.copy_(q_shard) + k_data.copy_(k_shard) + v_data.copy_(v_shard) + elif isinstance(param, GateUpShardedParameter): + gate_shard = shards_weights[f"{name}_gate"] + up_shard = shards_weights[f"{name}_up"] + gate_data, up_data = param.chunk(2, dim=param.shard_dim) + gate_data.copy_(gate_shard) + up_data.copy_(up_shard) + else: + param.data.copy_(shards_weights[name]) + # if name in shards_weights.keys(): + # param.data.copy_(shards_weights[name]) + self.model.embed_tokens.update_sharded_indices(shard_ids=self.shard_ids, total_num_shards=self.total_num_shards) + # self.shard_ids.append(shard_id) + for layer in self.model.layers: + layer.self_attn.update_param() + + def append_shards_weights(self, shard_ids: List[int], shards_weights: Dict[str, torch.Tensor]): + + if len(shard_ids) == 1: + start_shard_id = shard_ids[0] + end_shard_id = start_shard_id+1 + else: + start_shard_id = shard_ids[0] + end_shard_id = shard_ids[-1]+1 + # print(f"Before entering for loop, {get_cuda_mem_info()}") + # torch.cuda.memory._record_memory_history(max_entries=100000, context="all") + with torch.no_grad(): + for name, param in self.sorted_named_parameters(False): + if isinstance(param, QKVShardedParameter): + q_shard = shards_weights[f"{name}_q"] + k_shard = shards_weights[f"{name}_k"] + v_shard = shards_weights[f"{name}_v"] + param.append_shards(start_shard_id, end_shard_id ,q_shard, k_shard, v_shard) + # param.extend_and_load_shard(q_shard, k_shard, v_shard) + del q_shard, k_shard, v_shard + del shards_weights[f"{name}_q"], shards_weights[f"{name}_k"], shards_weights[f"{name}_v"] + # torch.cuda.empty_cache() + elif isinstance(param, GateUpShardedParameter): + gate_shard = shards_weights[f"{name}_gate"] + up_shard = shards_weights[f"{name}_up"] + param.append_shards(start_shard_id, end_shard_id ,gate_shard, up_shard) + # param.extend_and_load_shard(gate_shard, up_shard) + del gate_shard, up_shard + del shards_weights[f"{name}_gate"], shards_weights[f"{name}_up"] + # torch.cuda.empty_cache() + elif isinstance(param, ShardedParameter): + param.append_shards(start_shard_id, end_shard_id ,shards_weights[name]) + # param.extend_and_load_shard(shard_data=shards_weights[name]) + del shards_weights[name] + torch.cuda.empty_cache() + # print(f"After exit for loop, {get_cuda_mem_info()}") + # torch.cuda.memory._dump_snapshot(f"./torch_mem_dump.pickle") + # 
torch.cuda.memory._record_memory_history(enabled=None) + + for layer in self.model.layers: + for shard_id in range(start_shard_id, end_shard_id): + layer.self_attn.attn.append_shard(shard_id) + + for shard_id in range(start_shard_id, end_shard_id): + self.shard_ids.append(shard_id) + self.model.embed_tokens.update_sharded_indices(shard_ids=self.shard_ids, total_num_shards=self.total_num_shards) + for layer in self.model.layers: + layer.self_attn.update_param() + diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 4791942897419..591fabe4b47c5 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -30,7 +30,7 @@ from vllm.model_executor.model_loader.utils import (get_model_architecture, set_default_torch_dtype) import time -from vllm.liquid.sharded_parameter import QKVShardedParameter, ShardedParameter +from vllm.liquid.sharded_parameter import QKVShardedParameter, ShardedParameter, GateUpShardedParameter logger = init_logger(__name__) @@ -173,6 +173,9 @@ def recv_shards(self, shard_ids: List[int], src: int, only_sharded: bool = False param_names.append(f"{name}_q") param_names.append(f"{name}_k") param_names.append(f"{name}_v") + elif isinstance(param, GateUpShardedParameter): + param_names.append(f"{name}_gate") + param_names.append(f"{name}_up") elif isinstance(param, ShardedParameter): param_names.append(name) else: diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 4c47a33191f75..02236e72fc49c 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -225,7 +225,7 @@ def liquid_kv_cache(self, shard_ids: List[int], src: int, dst: int, load_kv_cach del weight del shards_cache torch.cuda.empty_cache() - logger.info(f"After appending kvc shards, {get_cuda_mem_info(self.rank)}") + logger.info(f"After cleaning received kvc shards, {get_cuda_mem_info(self.rank)}") def save_sharded_state(
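
Illustrative note (not part of the patch): the new QKVShardedParameter splits the fused qkv_proj weight by the query/KV head ratio instead of an even three-way chunk, which is what makes GQA models such as Llama-3 shardable. A minimal standalone sketch of that split, assuming Llama-3-8B dimensions (32 query heads, 8 KV heads, head_dim 128, hidden_size 4096); the helper name split_qkv is hypothetical:

import math
import torch

def split_qkv(fused: torch.Tensor, num_heads: int, num_kv_heads: int, dim: int = 0):
    # Reduce the head counts to their smallest ratio, e.g. 32:8 -> 4:1.
    d = math.gcd(num_heads, num_kv_heads)
    ratio_q, ratio_kv = num_heads // d, num_kv_heads // d
    denom = ratio_q + 2 * ratio_kv  # q takes ratio_q units, k and v take ratio_kv each
    size = fused.size(dim)
    assert size % denom == 0, "fused QKV dim must split evenly by the head ratio"
    unit = size // denom
    q = fused.narrow(dim, 0, unit * ratio_q)
    k = fused.narrow(dim, unit * ratio_q, unit * ratio_kv)
    v = fused.narrow(dim, unit * (ratio_q + ratio_kv), unit * ratio_kv)
    return q, k, v

# Fused qkv_proj weight for Llama-3-8B held on a single GPU:
# (32 + 2 * 8) heads * head_dim 128 = 6144 rows, hidden_size 4096 columns.
fused = torch.empty(6144, 4096)
q, k, v = split_qkv(fused, num_heads=32, num_kv_heads=8)
print(q.shape, k.shape, v.shape)
# torch.Size([4096, 4096]) torch.Size([1024, 4096]) torch.Size([1024, 4096])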
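Similarly, GateUpShardedParameter treats the merged gate_up_proj weight as [gate; up] stacked along the output dimension and slices the same column range out of each half for a given shard. A sketch of that addressing, assuming each half divides evenly across the liquid shards and Llama-3-8B's intermediate_size of 14336; get_gate_up_shard is a hypothetical helper, not the patch's API:

import torch

def get_gate_up_shard(merged: torch.Tensor, shard_id: int, total_num_shards: int, dim: int = 0):
    # The merged weight is [gate; up] along `dim`; take the shard's slice from each half.
    gate, up = merged.chunk(2, dim=dim)
    shard_size = gate.size(dim) // total_num_shards
    gate_shard = gate.narrow(dim, shard_id * shard_size, shard_size)
    up_shard = up.narrow(dim, shard_id * shard_size, shard_size)
    return gate_shard, up_shard

# Llama-3-8B MLP: intermediate_size 14336, hidden_size 4096, 4 liquid shards.
merged = torch.empty(2 * 14336, 4096)
gate_shard, up_shard = get_gate_up_shard(merged, shard_id=1, total_num_shards=4)
print(gate_shard.shape, up_shard.shape)
# torch.Size([3584, 4096]) torch.Size([3584, 4096])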
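Finally, after shards are appended or deleted, LlamaAttention.update_param recomputes the per-rank head counts from the resident shard_ids rather than from the tensor-parallel world size. The arithmetic, sketched with assumed Llama-3-8B head counts and liquid_total_num_shards = 4:

def per_rank_heads(shard_ids, total_num_heads=32, total_num_kv_heads=8, total_num_shards=4):
    # Mirrors update_param: scale the global head counts by the fraction of resident shards.
    num_shards = len(shard_ids)
    num_heads = total_num_heads * num_shards // total_num_shards
    num_kv_heads = max(1, total_num_kv_heads * num_shards // total_num_shards)
    return num_heads, num_kv_heads

print(per_rank_heads([0]))           # (8, 2)   one resident shard
print(per_rank_heads([0, 1]))        # (16, 4)  after appending shard 1
print(per_rank_heads([0, 1, 2, 3]))  # (32, 8)  all four shards on one GPU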