Initial FP8 support #75

Merged (14 commits) on Jun 28, 2024
Changes from all commits
2 changes: 1 addition & 1 deletion setup.py
@@ -229,7 +229,7 @@ def _is_neuron() -> bool:
torch_neuronx_installed = True
try:
subprocess.run(["neuron-ls"], capture_output=True, check=True)
- except (FileNotFoundError, PermissionError, subprocess.CalledProcessError):
+ except (FileNotFoundError, NotADirectoryError, PermissionError, subprocess.CalledProcessError):
torch_neuronx_installed = False
return torch_neuronx_installed or envs.VLLM_BUILD_WITH_NEURON

30 changes: 22 additions & 8 deletions vllm/attention/backends/habana_attn.py
@@ -5,9 +5,11 @@
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Type

+ import os
import torch
import math
- import vllm.hpu.xops as xops
+ from vllm.hpu import cache_ops, xops
+ from vllm.hpu.utils import Matmul, Softmax, VLLMKVCache
from vllm.hpu.attn_bias import (AttentionBias,
LowerTriangularMaskWithTensorBias)

@@ -111,7 +113,7 @@ def __post_init__(self):
self.attn_bias: Optional[List[AttentionBias]] = None


- class HabanaAttentionImpl(AttentionImpl):
+ class HabanaAttentionImpl(AttentionImpl, torch.nn.Module):
"""
If the input tensors contain prompt tokens, the layout is as follows:
|<--------------- num_prefill_tokens ----------------->|
@@ -137,8 +139,14 @@ def __init__(
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
) -> None:
super(AttentionImpl, self).__init__()
self.num_heads = num_heads
self.head_size = head_size
self.qk_matmul = Matmul()
self.softmax = Softmax()
self.kv_matmul = Matmul()
self.key_cache = VLLMKVCache()
self.value_cache = VLLMKVCache()
self.scale = float(scale)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.sliding_window = sliding_window
@@ -188,11 +196,9 @@ def forward(
# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
- HabanaPagedAttention.write_to_paged_cache(key, value, key_cache,
-                                           value_cache,
-                                           attn_metadata.slot_mapping,
-                                           attn_metadata.kv_cache_dtype,
-                                           attn_metadata.prefill_metadata is not None)
+ block_indices, block_offset = cache_ops.prepare_to_cache(key_cache, attn_metadata.slot_mapping)
+ key_cache = self.key_cache(key, key_cache, block_indices, block_offset)
+ value_cache = self.value_cache(value, value_cache, block_indices, block_offset)

if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
@@ -208,6 +214,9 @@ def forward(
attn_bias=prefill_meta.attn_bias,
p=0.0,
scale=self.scale,
qk_matmul_op=self.qk_matmul,
softmax_op=self.softmax,
kv_matmul_op=self.kv_matmul,
)
output = out.reshape(batch_size, seq_len, hidden_size)
else:
@@ -237,7 +246,12 @@ def forward(
self.num_kv_heads,
self.scale,
self.alibi_slopes,
- kv_scale
+ kv_scale,
+ self.qk_matmul,
+ self.softmax,
+ self.kv_matmul,
+ self.key_cache.fetch_from_cache,
+ self.value_cache.fetch_from_cache,
)

# Reshape the output tensor.
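Note: the point of routing every matmul, softmax, and KV-cache access through the `Matmul`, `Softmax`, and `VLLMKVCache` wrappers (and of making `HabanaAttentionImpl` a `torch.nn.Module`) is that these ops become named submodules that an FP8 quantization toolkit can discover and swap out. The sketch below only illustrates that substitution pattern; `Fp8Matmul`, `swap_matmuls`, and the scale handling are hypothetical and not part of this PR.

```python
import torch


class Matmul(torch.nn.Module):
    """Thin wrapper (as in vllm/hpu/utils.py) so matmuls are named submodules."""

    def forward(self, x, y):
        return torch.matmul(x, y)


class Fp8Matmul(torch.nn.Module):
    """Hypothetical FP8 drop-in, used only to illustrate module swapping.

    Requires a PyTorch build with float8 dtypes; real HPU kernels would stay
    in FP8 instead of upcasting for the multiply.
    """

    def __init__(self, scale: float = 1.0):
        super().__init__()
        self.scale = scale

    def forward(self, x, y):
        x8 = (x / self.scale).to(torch.float8_e4m3fn)
        y8 = (y / self.scale).to(torch.float8_e4m3fn)
        return torch.matmul(x8.to(x.dtype), y8.to(y.dtype)) * (self.scale * self.scale)


def swap_matmuls(module: torch.nn.Module) -> None:
    """Replace every Matmul child with the FP8 variant, recursively."""
    for name, child in module.named_children():
        if isinstance(child, Matmul):
            setattr(module, name, Fp8Matmul())
        else:
            swap_matmuls(child)
```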
10 changes: 10 additions & 0 deletions vllm/attention/ops/habana_paged_attn.py
@@ -85,6 +85,11 @@ def forward_decode(
scale: float,
alibi_slopes: Optional[torch.Tensor],
kv_scale: float,
qk_op=torch.matmul,
softmax_op=torch.softmax,
kv_op=torch.matmul,
keys_fetch=ops.fetch_from_cache,
values_fetch=ops.fetch_from_cache,
) -> torch.Tensor:
block_size = value_cache.shape[1]
return ops.paged_attention_v1(
@@ -98,6 +103,11 @@ def forward_decode(
block_size,
alibi_slopes,
kv_cache_dtype,
qk_op,
softmax_op,
kv_op,
keys_fetch,
values_fetch,
)

@staticmethod
7 changes: 5 additions & 2 deletions vllm/config.py
@@ -366,14 +366,15 @@ def _verify_args(self) -> None:
def _verify_cache_dtype(self) -> None:
if self.cache_dtype == "auto":
pass
- elif self.cache_dtype == "fp8":
+ elif self.cache_dtype in ["fp8", "hf8"]:
logger.info(
"Using fp8 data type to store kv cache. It reduces the GPU "
"memory footprint and boosts the performance. "
"But it may cause slight accuracy drop without scaling "
"factors. FP8_E5M2 (without scaling) is only supported on "
"cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 "
"is instead supported for common inference criteria.")
"is instead supported for common inference criteria. "
"FP8_E4M3 is also supported on hpu.")
else:
raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")

@@ -474,10 +475,12 @@ class LoadConfig:
mainly for profiling.
"tensorizer" will use CoreWeave's tensorizer library for
fast weight loading.
device: Device on which weights are loaded.
"""

load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
download_dir: Optional[str] = None
device: Optional[str] = None
model_loader_extra_config: Optional[Union[str, dict]] = field(
default_factory=dict)

13 changes: 11 additions & 2 deletions vllm/engine/arg_utils.py
@@ -28,6 +28,7 @@ class EngineArgs:
trust_remote_code: bool = False
download_dir: Optional[str] = None
load_format: str = 'auto'
weights_load_device: Optional[str] = None
dtype: str = 'auto'
kv_cache_dtype: str = 'auto'
quantization_param_path: Optional[str] = None
@@ -168,6 +169,11 @@ def add_cli_args(
'* "tensorizer" will load the weights using tensorizer from '
'CoreWeave which assumes tensorizer_uri is set to the location of '
'the serialized weights.')
parser.add_argument("--weights-load-device",
type=str,
default=EngineArgs.weights_load_device,
choices=["cuda", "neuron", "hpu", "cpu"],
help='Device on which weights are loaded.')
parser.add_argument(
'--dtype',
type=str,
@@ -186,12 +192,13 @@
parser.add_argument(
'--kv-cache-dtype',
type=str,
- choices=['auto', 'fp8'],
+ choices=['auto', 'fp8', 'hf8'],
default=EngineArgs.kv_cache_dtype,
help='Data type for kv cache storage. If "auto", will use model '
'data type. FP8_E5M2 (without scaling) is only supported on cuda '
'version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
- 'supported for common inference criteria.')
+ 'supported for common inference criteria. FP8_E4M3 is also supported '
+ 'on hpu.')
parser.add_argument(
'--quantization-param-path',
type=nullable_str,
@@ -574,9 +581,11 @@ def create_engine_config(self, ) -> EngineConfig:
max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
and self.max_cpu_loras > 0 else None) if self.enable_lora else None

device = device_config.device_type if self.weights_load_device is None else self.weights_load_device
load_config = LoadConfig(
load_format=self.load_format,
download_dir=self.download_dir,
device=device,
model_loader_extra_config=self.model_loader_extra_config,
)

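Note: a minimal sketch of driving the two new options (`--weights-load-device` and the `hf8` choice for `--kv-cache-dtype`) from Python through the `EngineArgs` dataclass changed above. The model name is a placeholder, the `load_config`/`cache_config` fields read off the resulting `EngineConfig` are assumed to follow the usual vLLM layout, and building the config needs the model's Hugging Face config to be reachable; `hf8` is assumed to require this fork on Gaudi hardware.

```python
from vllm.engine.arg_utils import EngineArgs

# Equivalent CLI: --kv-cache-dtype hf8 --weights-load-device cpu
args = EngineArgs(
    model="meta-llama/Llama-2-7b-hf",  # placeholder model
    kv_cache_dtype="hf8",              # FP8-E4M3 KV cache on HPU
    weights_load_device="cpu",         # load checkpoint tensors on the CPU
)
config = args.create_engine_config()
# The device override lands in the load config, the dtype in the cache config.
print(config.load_config.device, config.cache_config.cache_dtype)
```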
8 changes: 6 additions & 2 deletions vllm/engine/llm_engine.py
@@ -104,8 +104,8 @@ def __init__(
"tokenizer_revision=%s, trust_remote_code=%s, dtype=%s, "
"max_seq_len=%d, download_dir=%r, load_format=%s, "
"tensor_parallel_size=%d, disable_custom_all_reduce=%s, "
"quantization=%s, enforce_eager=%s, kv_cache_dtype=%s, "
"quantization_param_path=%s, device_config=%s, "
"quantization=%s, weights_load_device=%s, enforce_eager=%s, "
"kv_cache_dtype=%s, quantization_param_path=%s, device_config=%s, "
"decoding_config=%r, seed=%d, served_model_name=%s)",
vllm.__version__,
model_config.model,
@@ -123,6 +123,7 @@ def __init__(
parallel_config.tensor_parallel_size,
parallel_config.disable_custom_all_reduce,
model_config.quantization,
load_config.device,
model_config.enforce_eager,
cache_config.cache_dtype,
model_config.quantization_param_path,
@@ -537,6 +538,9 @@ def _process_model_outputs(
request_outputs.append(request_output)
return request_outputs

def finish_measurements(self):
self.model_executor.finish_measurements()

def step(self) -> List[RequestOutput]:
"""Performs one decoding iteration and returns newly generated results.

3 changes: 3 additions & 0 deletions vllm/entrypoints/llm.py
@@ -134,6 +134,9 @@ def set_tokenizer(
) -> None:
self.llm_engine.tokenizer.tokenizer = tokenizer

def finish_measurements(self):
self.llm_engine.finish_measurements()

def generate(
self,
prompts: Optional[Union[str, List[str]]] = None,
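Note: `finish_measurements()` is plumbed from the public `LLM` API through `LLMEngine` and both Habana executors down to the driver worker, which suggests it finalizes the calibration measurements the FP8 flow collects while serving requests (the worker-side implementation is not part of this diff). A hedged usage sketch, with a placeholder model and the assumption that measurement collection is enabled through the HPU quantization toolkit's own configuration:

```python
from vllm import LLM, SamplingParams

# Calibration-style pass: run representative prompts, then finalize the
# measurements collected during the run. Assumes measurement collection is
# enabled through the HPU quantization toolkit's own configuration.
llm = LLM(model="meta-llama/Llama-2-7b-hf")  # placeholder model
prompts = ["Hello, my name is", "The capital of France is"]
llm.generate(prompts, SamplingParams(max_tokens=32))
llm.finish_measurements()  # new in this PR
```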
9 changes: 9 additions & 0 deletions vllm/executor/habana_executor.py
@@ -82,6 +82,9 @@ def initialize_cache(self, num_gpu_blocks : int, num_cpu_blocks) -> None:
self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
logger.info(f"init_cache_engine took {cache_init_m.get_summary_string()}")

def finish_measurements(self):
self.driver_worker.finish_measurements()

def execute_model(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
@@ -128,6 +131,12 @@ def check_health(self) -> None:
# it's running.
return

def shutdown(self) -> None:
self.driver_worker.shutdown_hqt()

def __del__(self):
self.shutdown()


class HabanaExecutorAsync(HabanaExecutor, ExecutorAsyncBase):

3 changes: 3 additions & 0 deletions vllm/executor/ray_habana_executor.py
@@ -146,6 +146,9 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
max_concurrent_workers=self.parallel_config.
max_parallel_loading_workers)

def finish_measurements(self):
self._run_workers("finish_measurements")

def execute_model(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
13 changes: 13 additions & 0 deletions vllm/hpu/cache_ops.py
@@ -19,6 +19,19 @@ def reshape_and_cache(key, value, key_cache, value_cache, slot_mapping, dtype, i
value_cache.index_put_((indices, offsets), value)


def prepare_to_cache(cache, slot_mapping):
block_size = cache.size(1)
slot_mapping = slot_mapping.flatten()
indices = torch.div(slot_mapping, block_size, rounding_mode="floor")
offsets = torch.fmod(slot_mapping, block_size)

return indices, offsets


def insert_or_update_cache(input, cache, block_indices, block_offsets):
cache.index_put_((block_indices, block_offsets), input)


def swap_blocks(src, dst, block_mapping):
index_src = torch.zeros((1,), dtype=torch.int32, device=src.device)
index_dst = torch.zeros((1,), dtype=torch.int32, device=dst.device)
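Note: a small worked example of the slot-to-block arithmetic that `prepare_to_cache` and `insert_or_update_cache` implement above; the shapes are toy values chosen only to make the index math visible.

```python
import torch

# Toy walk-through of the slot -> (block, offset) arithmetic above.
# With block_size = 8, flat slot 17 lands in block 2 at offset 1.
block_size = 8
slot_mapping = torch.tensor([[3, 9, 17]])  # three tokens of one sequence
flat = slot_mapping.flatten()
block_indices = torch.div(flat, block_size, rounding_mode="floor")  # tensor([0, 1, 2])
block_offsets = torch.fmod(flat, block_size)                        # tensor([3, 1, 1])

# insert_or_update_cache then scatters each token into its slot:
cache = torch.zeros(4, block_size, 2)  # (num_blocks, block_size, head_dim)
values = torch.randn(3, 2)             # one vector per token
cache.index_put_((block_indices, block_offsets), values)
```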
18 changes: 8 additions & 10 deletions vllm/hpu/ops.py
@@ -35,8 +35,8 @@ def fetch_from_cache(cache, blocks, permutations):
return [cache.index_select(0, blocks[:, i]).permute(permutations) for i in range(blocks.size(1))]


- @hpu_utils.with_mark_steps
- def paged_attention_v1(query, key_cache, value_cache, head_mapping, scale, block_tables, context_lens, block_size, alibi_slopes, kv_cache_dtype=None) -> None:
+ def paged_attention_v1(query, key_cache, value_cache, head_mapping, scale, block_tables, context_lens, block_size, alibi_slopes, kv_cache_dtype=None,
+                        qk_matmul_op=torch.matmul, softmax_op=torch.softmax, kv_matmul_op=torch.matmul, keys_fetch_func=fetch_from_cache, values_fetch_func=fetch_from_cache) -> None:
seq_len = block_tables.size(1)
batch_size, query_heads, _ = query.shape
_, _, kv_heads, _ = key_cache.shape
@@ -48,26 +48,24 @@ def paged_attention_v1(query, key_cache, value_cache, head_mapping, scale, block
.view(batch_size, 1, 1, -1))
query.mul_(scale)
query = query.unsqueeze(-2)
- keys = fetch_from_cache(key_cache, block_tables, (0, 2, 3, 1))
+ keys = keys_fetch_func(key_cache, block_tables, (0, 2, 3, 1))
if query_heads != kv_heads:
query = query.unflatten(1, (kv_heads, -1))
keys = [k.unflatten(1, (kv_heads, 1)) for k in keys]
mask = mask.unsqueeze(2)
+ attn_weights = [qk_matmul_op(query, k) for k in keys]
+ attn_weights = softmax_op(torch.cat(attn_weights, dim=-1).masked_fill(mask, min_inf),
+                           dim=-1)
-
- attn_weights = [torch.matmul(query, k) for k in keys]
- attn_weights = (torch.cat(attn_weights, dim=-1)
-                 .masked_fill(mask, min_inf)
-                 .softmax(dim=-1))

- values = fetch_from_cache(value_cache, block_tables, (0, 2, 1, 3))
+ values = values_fetch_func(value_cache, block_tables, (0, 2, 1, 3))
if PA_SPLIT_VALUE:
attn_weights = attn_weights.split(block_size, dim=-1)
else:
values = [torch.cat(values, dim=-2)]
attn_weights = [attn_weights]
if query_heads != kv_heads:
values = [v.unflatten(1, (kv_heads, 1)) for v in values]
- attn_weights = [torch.matmul(a, v) for a, v in zip(attn_weights, values)]
+ attn_weights = [kv_matmul_op(a, v) for a, v in zip(attn_weights, values)]
if query_heads != kv_heads:
attn_weights = [a.flatten(1, 2) for a in attn_weights]
attn_weights = sum(attn_weights)
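Note: the `PA_SPLIT_VALUE` path relies on the identity that splitting the softmaxed attention weights along the key axis and summing the per-block `attn @ V` products equals the single fused matmul. A quick numerical sanity check of that identity (not part of the diff, shapes are illustrative):

```python
import torch

# Splitting the softmaxed attention weights along the key axis and summing
# per-block attn @ V products equals the single fused matmul.
torch.manual_seed(0)
attn = torch.randn(1, 2, 1, 16).softmax(dim=-1)  # (batch, heads, q=1, keys)
values = torch.randn(1, 2, 16, 4)                # (batch, heads, keys, head_dim)
block_size = 8
full = torch.matmul(attn, values)
parts = [torch.matmul(a, v)
         for a, v in zip(attn.split(block_size, dim=-1),
                         values.split(block_size, dim=-2))]
assert torch.allclose(full, sum(parts), atol=1e-6)
```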
32 changes: 31 additions & 1 deletion vllm/hpu/utils.py
@@ -5,7 +5,9 @@
# LICENSE file in the root directory of this source tree.
###############################################################################

import torch
import habana_frameworks.torch as htorch
from vllm.hpu.cache_ops import insert_or_update_cache

def with_mark_steps(fn):
def wrapped(*args, **kwargs):
@@ -96,4 +98,32 @@ def tfidf_backend(recipes):
cm.ax_.set_ylabel("Source recipe number")
plt.title(f'Recipe similarity ({backend_name})')
return plt
# plt.savefig('similarity.png')


class Matmul(torch.nn.Module):
def __init__(self):
super(Matmul, self).__init__()

def forward(self, x, y):
return torch.matmul(x, y)


class Softmax(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x, dim = None, inv_head = None):
return torch.softmax(x, dim)


class VLLMKVCache(torch.nn.Module):
def __init__(self):
super(VLLMKVCache, self).__init__()

def forward(self, input, cache, block_indices, block_offset):
insert_or_update_cache(input, cache, block_indices, block_offset)
return cache

def fetch_from_cache(self, cache, blocks, permutations):
return [cache.index_select(0, blocks[:, i]).permute(permutations) for i in range(blocks.size(1))]
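Note: a short usage sketch tying `VLLMKVCache`'s write path (`forward`, which delegates to `insert_or_update_cache`) to its read path (`fetch_from_cache`). It assumes an HPU-enabled install where `vllm.hpu.utils` imports cleanly; the shapes are toy values.

```python
import torch

# Assumes an HPU-enabled install where vllm.hpu.utils imports cleanly.
from vllm.hpu.utils import VLLMKVCache

num_blocks, block_size, num_kv_heads, head_dim = 4, 8, 1, 2
key_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_dim)
kv_cache = VLLMKVCache()

# Write path: three new tokens land at the (block, offset) slots computed by
# cache_ops.prepare_to_cache.
key = torch.randn(3, num_kv_heads, head_dim)
block_indices = torch.tensor([0, 1, 2])
block_offsets = torch.tensor([3, 1, 1])
key_cache = kv_cache(key, key_cache, block_indices, block_offsets)

# Read path: gather one sequence's blocks and permute them into the
# (batch, kv_heads, head_dim, block_size) layout paged_attention_v1 uses.
block_tables = torch.tensor([[0, 1, 2]])
keys = kv_cache.fetch_from_cache(key_cache, block_tables, (0, 2, 3, 1))
```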