This repository has been archived by the owner on Jul 24, 2024. It is now read-only.

Commit

Merge pull request vllm-project#1 from DeepAuto-AI/geon-dev
merge code
mujjingun authored Mar 10, 2024
2 parents 1c03585 + 7f2a7d8 commit d9d746e
Showing 5 changed files with 209 additions and 83 deletions.
232 changes: 156 additions & 76 deletions vllm/model_executor/layers/attention.py
@@ -19,8 +19,12 @@
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE = 512

from timber.models.timber_attention.attention1_block_gpu import paged_timber_attention

from timber.models.timber_attention.attention1_block_gpu import (
paged_timber_attention,
timber_attention
)
from vllm.transformers_utils import config as vllm_transformers_config
from timber.utils import get_bench
BENCHMARK_ITERATION = 0

class PagedAttention(nn.Module):
@@ -44,6 +48,7 @@ def __init__(
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
layer_index: Optional[int] = None,
) -> None:
super().__init__()
self.num_heads = num_heads
@@ -61,6 +66,8 @@ def __init__(
if self.head_size not in _SUPPORTED_HEAD_SIZES:
raise ValueError(f"head_size ({self.head_size}) is not supported. "
f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")

self.layer_index = layer_index

def forward(
self,
@@ -106,88 +113,160 @@ def forward(
input_metadata.slot_mapping.flatten(),
input_metadata.kv_cache_dtype,
)

hip_k = int(os.environ.get('HIP_K', '1024'))

if input_metadata.is_prompt:
# Prompt run.
if self.num_kv_heads != self.num_heads:
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# project the key and value tensors to the desired number of
# heads.
# TODO(woosuk): Use MQA/GQA kernels for higher performance.
query = query.view(query.shape[0], self.num_kv_heads,
self.num_queries_per_kv, query.shape[-1])
key = key[:, :,
None, :].expand(key.shape[0], self.num_kv_heads,
self.num_queries_per_kv,
key.shape[-1])
value = value[:, :, None, :].expand(value.shape[0],
self.num_kv_heads,
self.num_queries_per_kv,
value.shape[-1])
# normal attention
if (key_cache is None or value_cache is None
or input_metadata.block_tables.numel() == 0):
# Set attention bias if not provided. This typically happens at
# the very attention layer of every iteration.
# FIXME(woosuk): This is a hack.
if input_metadata.attn_bias is None:
BENCHMARK_PROMPT_ATTENTION = os.environ.get('BENCHMARK_PAGED_ATTENTION', '0') == '1'
backend = os.environ.get('PROMPT_ATTENTION_BACKEND', 'vllm')
is_normal_attention = (key_cache is None) or (value_cache is None) or (input_metadata.block_tables.numel() == 0)
if backend == 'vllm':
if self.num_kv_heads != self.num_heads:
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# project the key and value tensors to the desired number of
# heads.
# TODO(woosuk): Use MQA/GQA kernels for higher performance.
query = query.view(
query.shape[0],
self.num_kv_heads,
self.num_queries_per_kv,
query.shape[-1],
)
key = key[:, :, None, :]\
.expand(
key.shape[0],
self.num_kv_heads,
self.num_queries_per_kv,
key.shape[-1]
)
value = value[:, :, None, :]\
.expand(
value.shape[0],
self.num_kv_heads,
self.num_queries_per_kv,
value.shape[-1]
)
# normal attention
if is_normal_attention:
# Set attention bias if not provided. This typically happens at
# the very attention layer of every iteration.
# FIXME(woosuk): This is a hack.
if input_metadata.attn_bias is None:
if self.alibi_slopes is None:
attn_bias = BlockDiagonalCausalMask.from_seqlens(
[seq_len] * batch_size)
if self.sliding_window is not None:
attn_bias = attn_bias.make_local_attention(
self.sliding_window)
input_metadata.attn_bias = attn_bias
else:
input_metadata.attn_bias = _make_alibi_bias(
self.alibi_slopes, self.num_kv_heads, batch_size,
seq_len, query.dtype)

# TODO(woosuk): Too many view operations. Let's try to reduce
# them in the future for code readability.
if self.alibi_slopes is None:
attn_bias = BlockDiagonalCausalMask.from_seqlens(
[seq_len] * batch_size)
if self.sliding_window is not None:
attn_bias = attn_bias.make_local_attention(
self.sliding_window)
input_metadata.attn_bias = attn_bias
query = query.unsqueeze(0)
key = key.unsqueeze(0)
value = value.unsqueeze(0)
else:
input_metadata.attn_bias = _make_alibi_bias(
self.alibi_slopes, self.num_kv_heads, batch_size,
seq_len, query.dtype)
query = query.unflatten(0, (batch_size, seq_len))
key = key.unflatten(0, (batch_size, seq_len))
value = value.unflatten(0, (batch_size, seq_len))

# TODO(woosuk): Too many view operations. Let's try to reduce
# them in the future for code readability.
if self.alibi_slopes is None:
query = query.unsqueeze(0)
key = key.unsqueeze(0)
value = value.unsqueeze(0)
if BENCHMARK_PROMPT_ATTENTION:
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()

out = xops.memory_efficient_attention_forward(
query,
key,
value,
attn_bias=input_metadata.attn_bias,
p=0.0,
scale=self.scale,
op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if
(is_hip()) else None,
)
output = out.view_as(query)

if BENCHMARK_PROMPT_ATTENTION:
end.record()
torch.cuda.synchronize()
print(backend, start.elapsed_time(end), output.shape, end='\n')
else:
query = query.unflatten(0, (batch_size, seq_len))
key = key.unflatten(0, (batch_size, seq_len))
value = value.unflatten(0, (batch_size, seq_len))

out = xops.memory_efficient_attention_forward(
query,
key,
value,
attn_bias=input_metadata.attn_bias,
p=0.0,
scale=self.scale,
op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if
(is_hip()) else None,
# prefix-enabled attention
output = torch.empty_like(query)
context_attention_fwd(
query,
key,
value,
output,
key_cache,
value_cache,
input_metadata.block_tables, # [BS, max_block_per_request]
input_metadata.start_loc,
input_metadata.prompt_lens,
input_metadata.context_lens,
input_metadata.max_seq_len,
getattr(self, "alibi_slopes", None),
)
elif backend == 'timber':
# timber support MQA/GQA
warnings.warn('prompt attention backend is timber')

TDST, H, HID = query.shape
TSRC, H_KV, _HID = key.shape
assert key.shape[:-1] == value.shape[:-1]
assert HID == _HID

query = query.permute(1, 0, 2)
key = key.permute(1, 0, 2)
value = value.permute(1, 0, 2)

if BENCHMARK_PROMPT_ATTENTION:
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()

assert input_metadata.attn_bias is None
assert self.alibi_slopes is None

output, _ = timber_attention(
q=query * self.scale,
k=key,
v=value,
attention_mask=None,
mask_k=hip_k,
block_size_q=32,
block_size_k=2,
)
output = out.view_as(query)

output = output.permute(1, 0, 2)
output = output.view(
1,
TDST,
H,
HID,
).contiguous()

if BENCHMARK_PROMPT_ATTENTION:
end.record()
torch.cuda.synchronize()
print(backend, start.elapsed_time(end), output.shape, end='\n')
else:
# prefix-enabled attention
output = torch.empty_like(query)
context_attention_fwd(
query,
key,
value,
output,
key_cache,
value_cache,
input_metadata.block_tables, # [BS, max_block_per_request]
input_metadata.start_loc,
input_metadata.prompt_lens,
input_metadata.context_lens,
input_metadata.max_seq_len,
getattr(self, "alibi_slopes", None),
)

raise Exception(backend)
else:
# Decoding run.
BENCHMARK_PAGED_ATTENTION = os.environ.get('BENCHMARK_PAGED_ATTENTION', '0') == '1'

# print(f'[{os.getpid()}, {self.layer_index}] query_size: {query.shape}, block_table: {input_metadata.block_tables.shape}[{input_metadata.max_context_len}/{input_metadata.max_seq_len}]')

if BENCHMARK_PAGED_ATTENTION:
warnings.warn(f'query_size: {query.shape}, block_table: {input_metadata.block_tables.shape}[{input_metadata.max_context_len}/{input_metadata.max_seq_len}]')
warnings.warn(f'query_size: {query.shape}({query.dtype}), block_table: {input_metadata.block_tables.shape}[{input_metadata.max_context_len}/{input_metadata.max_seq_len}]')
torch.cuda.synchronize()
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
@@ -203,9 +282,9 @@ def forward(
self.num_kv_heads,
self.scale,
self.alibi_slopes,
)
)
elif backend == 'timber':
warnings.warn('backend is timber')
warnings.warn('paged attention backend is timber')

output, _ = paged_timber_attention(
q=query,
@@ -216,9 +295,9 @@
context_lens=input_metadata.context_lens,
max_context_len=input_metadata.max_context_len,
attention_mask=None,
mask_k=1024,
mask_k=hip_k,
block_size_q=32,
block_size_k=2,
block_size_q=16
)

N_H, _, HID = output.shape
@@ -243,11 +322,12 @@ def forward(
"alibi_slopes": self.alibi_slopes,
"output": output,
}, 'cache/llama/vllmout.pth')
print('saved cache/llama/vllmout.pth')

if BENCHMARK_PAGED_ATTENTION:
end.record()
torch.cuda.synchronize()
print(start.elapsed_time(end))
print(f'({backend}) {start.elapsed_time(end)}', end='\r')

# Reshape the output tensor.
return output.view(batch_size, seq_len, hidden_size)
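
Note on the timber prompt path in the hunk above: query/key/value arrive token-major as (tokens, heads, head_dim), are permuted to head-major with the query pre-scaled, and the kernel output is reshaped back to (1, TDST, H, HID). Below is a minimal shape-checking sketch that substitutes torch.nn.functional.scaled_dot_product_attention (PyTorch >= 2.1 for the scale argument) for timber_attention, and that assumes num_heads == num_kv_heads and equal query/key lengths, which the real kernel does not require:

    import torch
    import torch.nn.functional as F

    def prompt_layout_check(query: torch.Tensor, key: torch.Tensor,
                            value: torch.Tensor, scale: float) -> torch.Tensor:
        # query/key/value: (tokens, heads, head_dim), as in the prompt branch above.
        TDST, H, HID = query.shape
        q = (query * scale).permute(1, 0, 2)  # pre-scaled query, head-major, like the diff
        k = key.permute(1, 0, 2)
        v = value.permute(1, 0, 2)
        # Stand-in kernel; scale=1.0 because the query is already scaled.
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True, scale=1.0)
        # Back to (1, TDST, H, HID); reshape (not view) since the permute is non-contiguous.
        return out.permute(1, 0, 2).reshape(1, TDST, H, HID)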
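Taken together, the attention.py changes gate the new code paths behind environment variables: PROMPT_ATTENTION_BACKEND selects 'vllm' or 'timber' for prompt runs, HIP_K sets the mask_k budget passed to timber_attention / paged_timber_attention, and BENCHMARK_PAGED_ATTENTION enables the CUDA-event timing prints on both the prompt and decode paths. A hedged usage sketch (the decode-side backend variable is not visible in the hunks above, so it is not shown):

    import os

    # Set before the engine (and its PagedAttention layers) is constructed.
    os.environ["PROMPT_ATTENTION_BACKEND"] = "timber"  # diff default is 'vllm'
    os.environ["HIP_K"] = "1024"                       # mask_k for timber attention
    os.environ["BENCHMARK_PAGED_ATTENTION"] = "1"      # print per-call elapsed times
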
4 changes: 3 additions & 1 deletion vllm/model_executor/layers/rotary_embedding.py
@@ -306,7 +306,9 @@ def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
# Get n-d rotational scaling corrected for extrapolation
inv_freq_mask = (1 - _yarn_linear_ramp_mask(
low, high, self.rotary_dim // 2,
dtype=torch.float)) * self.extrapolation_factor
dtype=torch.float,
device=pos_freqs.device
)) * self.extrapolation_factor
inv_freq = inv_freq_interpolation * (
1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
return inv_freq
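The rotary_embedding.py change is a device fix: the YaRN ramp mask is now created on the same device as pos_freqs, so multiplying it with the interpolation/extrapolation frequencies no longer mixes CPU and CUDA tensors. A minimal sketch of the assumed helper shape (not the fork's exact _yarn_linear_ramp_mask):

    import torch

    def yarn_linear_ramp_mask(low: float, high: float, dim: int,
                              dtype: torch.dtype, device: torch.device) -> torch.Tensor:
        # Ramp from 0 to 1 between `low` and `high`, built directly on `device`
        # so downstream arithmetic with same-device tensors needs no transfer.
        if low == high:
            high += 0.001  # avoid division by zero, as in common YaRN implementations
        ramp = (torch.arange(dim, dtype=dtype, device=device) - low) / (high - low)
        return torch.clamp(ramp, 0, 1)
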
7 changes: 7 additions & 0 deletions vllm/model_executor/models/llama.py
@@ -342,11 +342,15 @@ def load_weights(self,
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
# print('vllm.load_weight: ignore', weight_name)
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if name not in params_dict:
print('vllm.load_weight: ignore', name)
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
@@ -355,6 +359,9 @@
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if name not in params_dict:
print('vllm.load_weight: ignore', name)
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
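The llama.py change adds a guard to load_weights: checkpoint tensors whose (possibly remapped) names have no matching model parameter are logged and skipped instead of raising a KeyError. A self-contained sketch of the same pattern, with a hypothetical helper name rather than the fork's API:

    from typing import Dict, Iterable, Tuple
    import torch

    def load_known_weights(params_dict: Dict[str, torch.nn.Parameter],
                           weights: Iterable[Tuple[str, torch.Tensor]]) -> None:
        for name, loaded_weight in weights:
            if name not in params_dict:
                print('vllm.load_weight: ignore', name)  # same message as in the diff
                continue
            with torch.no_grad():
                params_dict[name].copy_(loaded_weight)
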
7 changes: 4 additions & 3 deletions vllm/transformers_utils/config.py
@@ -16,7 +16,7 @@
}

# NOTE: For benchmarking
FORCE_SIGNLE_LAYER = False
FORCE_SIGNLE_LAYER = 0

def get_config(
model: str,
@@ -44,7 +44,8 @@ def get_config(
config = config_class.from_pretrained(model, revision=revision)

# NOTE: DEBUG
if FORCE_SIGNLE_LAYER:
config.num_hidden_layers = 1
if FORCE_SIGNLE_LAYER > 0:
assert isinstance(FORCE_SIGNLE_LAYER, int)
config.num_hidden_layers = FORCE_SIGNLE_LAYER

return config
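
The config.py change turns FORCE_SIGNLE_LAYER from a boolean into a layer count, so benchmarking can truncate the model to the first N transformer blocks rather than exactly one. A hedged sketch of the effect, using Hugging Face's AutoConfig instead of the fork's config loader:

    from transformers import AutoConfig

    def truncated_config(model: str, force_layers: int = 0):
        # Mirrors the FORCE_SIGNLE_LAYER > 0 branch above.
        config = AutoConfig.from_pretrained(model)
        if force_layers > 0:
            assert isinstance(force_layers, int)
            config.num_hidden_layers = force_layers
        return config

    # e.g. cfg = truncated_config("meta-llama/Llama-2-7b-hf", force_layers=2)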

