[core][distributed] accelerate distributed weight loading #6127

Closed

youkaichao wants to merge 39 commits into main from defer_tensor

Commits (39), all authored by youkaichao:
79b5348  add deferred tensor (Jul 4, 2024)
e49d7a3  fix tuple (Jul 4, 2024)
00db011  update model (Jul 4, 2024)
e27e5a6  format (Jul 4, 2024)
b7d2888  del file (Jul 4, 2024)
413e2b1  use torch.Size (Jul 4, 2024)
6eb782c  add more convert (Jul 4, 2024)
c12cd1a  more fix (Jul 4, 2024)
2b8a496  rename (Jul 4, 2024)
cfc896a  ensure_tensor (Jul 4, 2024)
ce69701  use type hint (Jul 4, 2024)
633a697  use copy_ (Jul 4, 2024)
7852efb  add logging (Jul 4, 2024)
91695ce  update default loader (Jul 4, 2024)
da9d486  fix using loaded_weight in load_weights (Jul 4, 2024)
6db5418  add more duck typing support (Jul 5, 2024)
971bb0c  revoke moe change (Jul 5, 2024)
907285b  revoke linear change (Jul 5, 2024)
ea9196e  revoke embedding change (Jul 5, 2024)
343212b  revoke more changes (Jul 5, 2024)
f41b2b5  add code (Jul 5, 2024)
ca2a9f7  update (Jul 5, 2024)
cb41625  test view (Jul 5, 2024)
24b0b3b  add torch.reshape (Jul 5, 2024)
86bc4d0  add t() (Jul 5, 2024)
1e29aaa  revoke more (Jul 5, 2024)
0180f0f  revoke more (Jul 5, 2024)
6e4c5d5  finish revoke (Jul 5, 2024)
b688f00  finish and try (Jul 5, 2024)
3375589  bugfix (Jul 5, 2024)
12af9f3  avoid bug of safetensors (Jul 8, 2024)
7f9011d  remove unused code (Jul 8, 2024)
f336117  update (Jul 10, 2024)
cb90984  lazy open file (Jul 10, 2024)
02a65f3  add more comments (Jul 10, 2024)
b17f6aa  bugfix (Jul 10, 2024)
0fcef6f  add more comments (Jul 10, 2024)
70afe57  bugfix (Jul 10, 2024)
4e0773c  Merge branch 'main' into defer_tensor (Jul 10, 2024)
49 changes: 49 additions & 0 deletions tests/distributed/test_utils.py
@@ -1,6 +1,9 @@
import ray
import torch

import vllm.envs as envs
from vllm.model_executor.model_loader.weight_utils import (
safetensors_weights_iterator)
from vllm.utils import (cuda_device_count_stateless, is_hip,
update_environment_variables)

@@ -36,3 +39,49 @@ def test_cuda_device_count_stateless():
assert ray.get(actor.get_count.remote()) == 1
ray.get(actor.set_cuda_visible_devices.remote(""))
assert ray.get(actor.get_count.remote()) == 0


def test_deferred_tensor():
from safetensors import safe_open
from safetensors.torch import save_file
tensors = {
"scalar": torch.ones(tuple()),
"vector": torch.ones(2),
"matrix": torch.ones((2, 3)),
"tensor": torch.ones((2, 3, 4)),
}
save_file(tensors, "model.safetensors")

for name, dt in safetensors_weights_iterator(["model.safetensors"]):
with safe_open("model.safetensors",
framework="pt") as f: # type: ignore
real_tensor = f.get_tensor(name)
real_tensor.copy_(dt) # test we can use `copy_`
stacked = torch.stack([real_tensor, real_tensor])
stacked[0] = dt # test we can use `__setitem__` to assign
if name != "scalar":
real_tensor[1:] = dt[1:] # test we can assign slices
if name in ["matrix", "tensor"]:
real_norm = torch.nn.functional.normalize(real_tensor)
dt_norm = torch.nn.functional.normalize(dt)
assert torch.allclose(real_norm,
dt_norm) # test we can use `normalize`
assert torch.allclose(real_tensor.cpu(),
dt.cpu()) # test we can move to device
assert torch.allclose(
real_tensor.to(dtype=torch.float64),
dt.to(dtype=torch.float64)) # test we can change dtype

assert torch.allclose(real_tensor + 1, dt + 1)
assert torch.allclose(real_tensor.float(),
dt.float()) # test we can use `.float()`
assert torch.allclose(real_tensor.data,
dt.data) # test we can use `.data`
assert torch.allclose(real_tensor.view(-1),
dt.view(-1)) # test we can use `view`
assert torch.allclose(torch.reshape(real_tensor, (-1, )),
torch.reshape(
dt, (-1, ))) # test we can use `reshape`
if name != "tensor":
assert torch.allclose(real_tensor.t(),
dt.t()) # test we can use `t()`
65 changes: 44 additions & 21 deletions vllm/model_executor/model_loader/weight_utils.py
@@ -6,7 +6,7 @@
import os
import tempfile
from collections import defaultdict
from typing import Any, Generator, Iterable, List, Optional, Tuple
from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple

import filelock
import huggingface_hub.constants
@@ -22,6 +22,7 @@
from vllm.model_executor.layers.quantization import (QuantizationConfig,
get_quantization_config)
from vllm.model_executor.layers.quantization.schema import QuantParamSchema
from vllm.utils import DeferredTensor

logger = init_logger(__name__)

@@ -350,15 +351,52 @@ def np_cache_weights_iterator(
yield name, torch.from_numpy(param)


def _parse_metadata_from_safetensors(
filepath: str) -> Dict[str, Dict[str, Any]]:
# format from https://huggingface.co/docs/safetensors/en/index#format
with open(filepath, "rb") as f:
size = int.from_bytes(f.read(8), "little")
data = json.loads(f.read(size).decode('utf-8'))
return data
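
# For reference, the parsed header maps each tensor name (plus an optional
# "__metadata__" entry) to its dtype, shape, and byte offsets, roughly:
#   {"model.embed_tokens.weight": {"dtype": "F16",
#                                  "shape": [32000, 4096],
#                                  "data_offsets": [0, 262144000]}}
# (the key and values above are illustrative, not from a real checkpoint)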


def safetensors_weights_iterator(
hf_weights_files: List[str]
) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model safetensor files."""
"""Iterate over the weights in the model safetensor files.
NOTE: we read the file as lazily as possible. If this process
does not need any weight inside a safetensor file (e.g. pipeline
parallel), that file is not opened by safetensors library at all.
"""
st_handles: Dict[str, Any] = {}

def layz_open_st(filename):
# lazily open safetensor files
if filename not in st_handles:
st_handles[filename] = safe_open(filename,
framework="pt").__enter__()
return st_handles[filename]

name_and_tensors = []
for st_file in hf_weights_files:
with safe_open(st_file, framework="pt") as f:
for name in f.keys(): # noqa: SIM118
param = f.get_tensor(name)
yield name, param
data = _parse_metadata_from_safetensors(st_file)
for k, v in data.items():
if k == "__metadata__":
continue
dtype = v["dtype"]
shape = v["shape"]
name_and_tensors.append(
[k, DeferredTensor(layz_open_st, st_file, k, dtype, shape)])
for name, v in name_and_tensors:
# We actually return a DeferredTensor here, but use `torch.Tensor`
# as the type hint to avoid changing too much user-side code.
# Users can treat this value just like a torch.Tensor, except that
# slicing and `narrow` are optimized for I/O.
yield name, v

for v in st_handles.values():
v.__exit__(None, None, None) # type: ignore


def pt_weights_iterator(
@@ -413,21 +451,6 @@ def kv_cache_scales_loader(
return []


def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
"""convert PySafeSlice object from safetensors to torch.Tensor

PySafeSlice object supports indexing, which is done before loading the
actual tensor and can reduce the amount of memory being read into the
memory. However, it does not support more advanced functionalities
like `.view()` or `.t()`. Therefore, if we need to modify the loaded
tensor with these more complicated operators, we need to convert to
tensor first.
"""
if not isinstance(x, torch.Tensor):
x = x[:]
return x


def default_weight_loader(param: torch.Tensor,
loaded_weight: torch.Tensor) -> None:
"""Default weight loader."""
206 changes: 206 additions & 0 deletions vllm/utils.py
@@ -943,3 +943,209 @@ def parse_args(self, args=None, namespace=None):
processed_args.append(arg)

return super().parse_args(processed_args, namespace)


class DeferredTensor:
Collaborator: Why isn't it possible to subclass/use a meta tensor, for example?

Member Author: A meta tensor is now used to respond to metadata-related queries.
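
For illustration, this is roughly how a meta tensor can answer metadata queries without reading any data (a sketch, not the exact code in this PR):

    import torch
    meta = torch.zeros((4096, 11008), dtype=torch.float16, device="meta")
    meta.shape, meta.dtype, meta.stride()  # served from metadata alone, no I/O
    # something like meta.tolist() would fail: a meta tensor has no storage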

"""This class is a placeholder for a tensor that is not materialized yet.
When we pass the object around, it will not materialize the tensor until
torch functions are called on it.
Notable exceptions are `shape`, `dtype`, `size`, `stride` which will be
returned directly without materializing the tensor.
Notable optimization is `narrow` method which will only materialize the
tensor slice that is narrowed, reducing the disk reads. Either `torch.narrow`
or `tensor.narrow` will materialize the tensor.

Basically, you can use instances of this class when you need values of the
tensor, but don't need in-place update of the tensor.
""" # noqa

def __init__(self, layz_open_st, st_file, name, dtype, shape):
self.layz_open_st = layz_open_st
self.st_file = st_file
self.name = name

# code from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/src/lib.rs#L40 # noqa
type_mapping = {
"BOOL": torch.bool,
"I8": torch.int8,
"U8": torch.uint8,
"I16": torch.int16,
"U16": torch.uint16,
"I32": torch.int32,
"U32": torch.uint32,
"I64": torch.int64,
"U64": torch.uint64,
"F16": torch.float16,
"F32": torch.float32,
"F64": torch.float64,
"BF16": torch.bfloat16,
"F8_E4M3": torch.float8_e4m3fn,
"F8_E5M2": torch.float8_e5m2
}
dtype = type_mapping[dtype]
shape = tuple(shape)
if shape:
self._meta_tensor = torch.zeros(*shape, dtype=dtype, device="meta")
else:
self._meta_tensor = torch.zeros(tuple(),
dtype=dtype,
device="meta")

def __getattr__(self, name):
if name in ["shape", "dtype", "size", "stride"]:
# redirect metadata information queries to the meta tensor
return getattr(self._meta_tensor, name)
if hasattr(torch.Tensor, name):
# all other torch.Tensor attributes materialize the tensor and are
# then resolved on the materialized tensor
tensor = self.materialize()
return getattr(tensor, name)
raise AttributeError(f"Attribute {name} not found")

def __getitem__(self, key) -> torch.Tensor:
return self.layz_open_st(self.st_file).get_slice(self.name)[key]

def materialize(self) -> torch.Tensor:
return self.layz_open_st(self.st_file).get_tensor(self.name)

def narrow(input, dim, start, length) -> torch.Tensor:
# `input` is a `DeferredTensor` object
# it does not use `self`, but `input` instead
# to better match https://pytorch.org/docs/stable/generated/torch.narrow.html signature # noqa

# `DeferredTensor` will only respond to `narrow` method
# which reads the corresponding data from disk and returns
# a materialized tensor
slices = [slice(None, None, None) for x in input._meta_tensor.shape]
slices[dim] = slice(start, start + length)
return input[tuple(slices)]

@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
if func == torch.narrow:
if len(args) >= 2:
kwargs["dim"] = args[1]
if len(args) >= 3:
kwargs["start"] = args[2]
if len(args) >= 4:
kwargs["length"] = args[3]
return args[0].narrow(**kwargs)
new_args = []
for arg in args:
if isinstance(arg, DeferredTensor):
new_args.append(arg.materialize())
else:
new_args.append(arg)
return func(*new_args, **kwargs)

# implement common tensor operations, except for in-place operations
Collaborator: Why do we need those in the first place?

Member Author: I wrote this code so that the deferred tensors behave just like normal tensors. For example, some users call torch.nn.functional.normalize on tensors loaded from disk:

    loaded_weight = torch.nn.functional.normalize(...)

When we have these functions, users' code is minimally affected.
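
As a rough illustration of the mechanism (a simplified sketch with a hypothetical `_Deferred` wrapper, not the exact class in this PR), `__torch_function__` can materialize deferred arguments before dispatching, so torch functions receive plain tensors and existing call sites keep working:

    import torch

    class _Deferred:
        def __init__(self, tensor):
            self._tensor = tensor  # stands in for data read lazily from disk

        def materialize(self) -> torch.Tensor:
            return self._tensor

        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            kwargs = kwargs or {}
            # swap any deferred argument for its materialized tensor
            args = tuple(a.materialize() if isinstance(a, _Deferred) else a
                         for a in args)
            return func(*args, **kwargs)

    w = _Deferred(torch.randn(4, 8))
    out = torch.nn.functional.normalize(w)  # user code is unchanged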

Collaborator: I think this will lead to issues later down the line (we now have an ill-defined subset of supported operations). I feel like explicitly requiring materialization is a better way.

Member Author: You can refer to da9d486 for an example; that is what I did before. I found that code more intrusive, and all third-party code using vLLM might break.

Collaborator: I think this is too magical and being explicit would be better, but I am not going to block on this.

Commenter: @Yard1 would you consider elaborating? I am not sure I agree that explicitly presenting the tensor bytearray content, in a person's mental model, ordered with respect to consumption of said bytearray, is wise, especially for such a large read-only buffer. I am of the opinion that any large read-only memory region is a massive target for performance optimization.

I have been working on a very small PR designed to improve availability by eliminating the model-load startup delay (about 1 minute on my gear; I can shave it to about 3 seconds). This is pretty straightforward to deliver, presuming that the logit operation does not need to be calculated on startup.

The gist of my WIP PR is:

  • Remove the logits calculation, which takes 12 seconds on my quite beastly GPU server with two A40s running llama3-70b-w8a8.
  • Rely on safetensors exclusively.
  • Modify, slightly, the management of PyTorch storage so that if a storage buffer address is accessed via any Tensor.to('cuda'), there is some intelligence applied to the h2d memcpy.
  • Modify, slightly, to('cuda') calls to send Tensor objects to CUDA devices, but not to execute a CUDA host-to-device copy (as is the default in PyTorch).

I would be interested in hearing your opinions on the approach, and a review when ready. There are some constraints:

  • safetensors-first implementation
  • CUDA-first implementation
  • CUDA must have UVM loaded
  • dedicated address-space consumption of the model size
  • dedicated (initially) memory consumption of the model size (this is an obnoxious constraint, and is not permanent)
  • no run-time quantization (I prefer statically compressed models, although I understand the value of some types of dynamic decompression, so we can kick this around)

One problem I struggle with is not technical: making this an optional feature (via a runtime environment variable) versus adding it as a flag.

Thanks, writing down my ideas in this area was helpful; I will copy this into an issue tracker.
cc/ @sdake
cc/ @robertgshaw2-neuralmagic

def __add__(self, other):
return self.materialize() + other

def __radd__(self, other):
return other + self.materialize()

def __sub__(self, other):
return self.materialize() - other

def __rsub__(self, other):
return other - self.materialize()

def __mul__(self, other):
return self.materialize() * other

def __rmul__(self, other):
return other * self.materialize()

def __truediv__(self, other):
return self.materialize() / other

def __rtruediv__(self, other):
return other / self.materialize()

def __floordiv__(self, other):
return self.materialize() // other

def __rfloordiv__(self, other):
return other // self.materialize()

def __mod__(self, other):
return self.materialize() % other

def __rmod__(self, other):
return other % self.materialize()

def __pow__(self, other):
return self.materialize()**other

def __rpow__(self, other):
return other**self.materialize()

def __matmul__(self, other):
return self.materialize() @ other

def __rmatmul__(self, other):
return other @ self.materialize()

def __and__(self, other):
return self.materialize() & other

def __rand__(self, other):
return other & self.materialize()

def __or__(self, other):
return self.materialize() | other

def __ror__(self, other):
return other | self.materialize()

def __xor__(self, other):
return self.materialize() ^ other

def __rxor__(self, other):
return other ^ self.materialize()

def __lshift__(self, other):
return self.materialize() << other

def __rlshift__(self, other):
return other << self.materialize()

def __rshift__(self, other):
return self.materialize() >> other

def __rrshift__(self, other):
return other >> self.materialize()

def __eq__(self, other):
return self.materialize() == other

def __ne__(self, other):
return self.materialize() != other

def __lt__(self, other):
return self.materialize() < other

def __le__(self, other):
return self.materialize() <= other

def __gt__(self, other):
return self.materialize() > other

def __ge__(self, other):
return self.materialize() >= other

def __neg__(self):
return -self.materialize()

def __pos__(self):
return +self.materialize()

def __abs__(self):
return abs(self.materialize())

def __invert__(self):
return ~self.materialize()
8 changes: 6 additions & 2 deletions vllm/worker/model_runner.py
@@ -252,6 +252,8 @@ def __init__(
self.flashinfer_prefill_wrapper = None

def load_model(self) -> None:
logger.info("Start loading model")
start_time = time.perf_counter()
with CudaMemoryProfiler() as m:
self.model = get_model(model_config=self.model_config,
device_config=self.device_config,
@@ -261,10 +263,12 @@
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config,
cache_config=self.cache_config)
end_time = time.perf_counter()
elapsed_time = end_time - start_time

self.model_memory_usage = m.consumed_memory
logger.info("Loading model weights took %.4f GB",
self.model_memory_usage / float(2**30))
logger.info("Loading model weights took %.4f GB memory and %.4f sec",
self.model_memory_usage / float(2**30), elapsed_time)

if self.lora_config:
assert supports_lora(self.model), "Model does not support LoRA"