Fix cuda oom (vllm-project#7)
add torch.cuda.empty_cache()
xiangyuT authored Oct 24, 2023
1 parent f78c169 commit 1ab029d
Showing 1 changed file with 15 additions and 10 deletions.
25 changes: 15 additions & 10 deletions vllm/model_executor/models/bigdl_llama.py
@@ -48,6 +48,7 @@ def __init__(
self.tokenizer = AutoTokenizer.from_pretrained(config._name_or_path)
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.dtype = self.model.config.torch_dtype
# self.tmp_kv_cache = []

def decode(self, generated_ids: List[int]) -> str:
return self.tokenizer.decode(
@@ -76,7 +77,6 @@ def forward(
all_decoding = all_decoding and (not seq_group_meta_data.is_prompt)
seq_ids = list(seq_group_meta_data.seq_data.keys())
seq_id = seq_ids[0]
print(seq_id)
cur_seq_ids.append(seq_id)
seq_data = seq_group_meta_data.seq_data[seq_id]

@@ -93,9 +93,13 @@
for seq_group_meta_data in seq_group_meta_data_lists:
seq_ids = list(seq_group_meta_data.seq_data.keys())
seq_id = seq_ids[0]
if kv_cache.get(seq_id) is None:
continue
for i in range(kv_cache_0):
for j in range(kv_cache_1):
bigdl_kv_cache[i][j] = torch.cat((bigdl_kv_cache[i][j], kv_cache[seq_id][i][j]), dim=0).to(dtype = self.dtype)
target_size = (bigdl_kv_cache[i][j].size(0) + kv_cache[seq_id][i][j].size(0),) + kv_cache[seq_id][i][j].size()[1:]
bigdl_kv_cache[i][j].resize_(target_size)
bigdl_kv_cache[i][j][-kv_cache[seq_id][i][j].size(0):] = kv_cache[seq_id][i][j]

bigdl_input_ids = torch.tensor(bigdl_input_ids, device=self.device)
bigdl_position_ids = torch.tensor(bigdl_position_ids, device=self.device)
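
The key change in the hunk above swaps torch.cat, which builds a brand-new concatenated tensor (plus an extra dtype-cast copy from the trailing .to(dtype=...)) while the old cache tensor is still alive, for in-place growth of the existing cache tensor. A minimal sketch of that pattern, not part of the diff and using illustrative names (append_rows, cache, block are not from this file); it assumes the cache tensor is contiguous so resize_ keeps the existing rows intact:

import torch

def append_rows(cache: torch.Tensor, block: torch.Tensor) -> torch.Tensor:
    # Same result as torch.cat((cache, block), dim=0), but grows the existing
    # tensor's storage in place instead of materializing a separate
    # concatenated copy.
    target_size = (cache.size(0) + block.size(0),) + block.size()[1:]
    cache.resize_(target_size)          # storage grows in place
    cache[-block.size(0):] = block      # copy the new rows into the tail
    return cache

cache = torch.zeros(2, 4, 8)
block = torch.ones(1, 4, 8)
print(append_rows(cache, block).shape)  # torch.Size([3, 4, 8])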
@@ -115,12 +119,13 @@ def forward(
"use_cache": True,
"return_dict": True,
}
# kwargs["position_ids"] = position_ids
# pdb.set_trace()
outputs = self.model.forward(**kwargs)
# self.tmp_kv_cache = outputs.past_key_values
index = 0
bigdl_output = []
for seq_id in cur_seq_ids:
# pdb.set_trace()
cur_sampling_params = bigdl_sampling_params[seq_id]
logits_processor = prepare_logits_processor(
cur_sampling_params.temperature, 1,
@@ -143,16 +148,16 @@ def forward(
kv_cache[seq_id] = [[[] for _ in range(kv_cache_1)] for _ in range(kv_cache_0)]
for i in range(kv_cache_0):
for j in range(kv_cache_1):
kv_cache[seq_id][i][j] = outputs.past_key_values[i][j][index].unsqueeze(0).to(device=self.device,dtype = self.dtype)
kv_cache[seq_id][i][j] = outputs.past_key_values[i][j][index].unsqueeze(0)
index = index + 1

#pdb.set_trace()
torch.cuda.empty_cache()

return bigdl_output

def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None):
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None):
pass
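
For context on the torch.cuda.empty_cache() call the commit adds at the end of forward: it does not free tensors that are still referenced, it only returns unused blocks held by PyTorch's caching allocator so the memory becomes visible to other consumers again. A minimal sketch of where the call sits, with placeholder names (decode_step, model, kwargs are not from this file):

import torch

def decode_step(model, kwargs):
    # Run one forward pass, then hand unused cached blocks back to the driver.
    with torch.no_grad():
        outputs = model(**kwargs)
    torch.cuda.empty_cache()
    return outputs

Calling it on every forward pass trades some allocator overhead for a lower reserved-memory footprint, which appears to be the intent of this OOM fix.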
