[core][distributed] accelerate distributed weight loading #6127
Changes from all commits
@@ -943,3 +943,209 @@ def parse_args(self, args=None, namespace=None):
                processed_args.append(arg)

        return super().parse_args(processed_args, namespace)


class DeferredTensor:
    """This class is a placeholder for a tensor that is not materialized yet.
    When we pass the object around, it will not materialize the tensor until
    torch functions are called on it.
    Notable exceptions are `shape`, `dtype`, `size`, and `stride`, which are
    returned directly without materializing the tensor.
    A notable optimization is the `narrow` method, which only materializes
    the tensor slice being narrowed, reducing disk reads. Either
    `torch.narrow` or `tensor.narrow` will materialize the tensor.

    Basically, you can use instances of this class when you need the values
    of the tensor but don't need in-place updates of the tensor.
    """  # noqa

    def __init__(self, layz_open_st, st_file, name, dtype, shape):
        self.layz_open_st = layz_open_st
        self.st_file = st_file
        self.name = name

        # code from https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/src/lib.rs#L40  # noqa
        type_mapping = {
            "BOOL": torch.bool,
            "I8": torch.int8,
            "U8": torch.uint8,
            "I16": torch.int16,
            "U16": torch.uint16,
            "I32": torch.int32,
            "U32": torch.uint32,
            "I64": torch.int64,
            "U64": torch.uint64,
            "F16": torch.float16,
            "F32": torch.float32,
            "F64": torch.float64,
            "BF16": torch.bfloat16,
            "F8_E4M3": torch.float8_e4m3fn,
            "F8_E5M2": torch.float8_e5m2,
        }
        dtype = type_mapping[dtype]
        shape = tuple(shape)
        if shape:
            self._meta_tensor = torch.zeros(*shape, dtype=dtype, device="meta")
        else:
            self._meta_tensor = torch.zeros(tuple(),
                                            dtype=dtype,
                                            device="meta")

    def __getattr__(self, name):
        if name in ["shape", "dtype", "size", "stride"]:
            # redirect metadata queries to the meta tensor
            return getattr(self._meta_tensor, name)
        if hasattr(torch.Tensor, name):
            # all other tensor attributes materialize the tensor and are
            # resolved on the materialized tensor
            tensor = self.materialize()
            return getattr(tensor, name)
        raise AttributeError(f"Attribute {name} not found")

    def __getitem__(self, key) -> torch.Tensor:
        return self.layz_open_st(self.st_file).get_slice(self.name)[key]

    def materialize(self) -> torch.Tensor:
        return self.layz_open_st(self.st_file).get_tensor(self.name)

    def narrow(input, dim, start, length) -> torch.Tensor:
        # `input` is a `DeferredTensor` object. The first argument is named
        # `input` rather than `self` to better match the signature of
        # https://pytorch.org/docs/stable/generated/torch.narrow.html  # noqa

        # `narrow` reads only the corresponding data from disk and returns
        # a materialized tensor
        slices = [slice(None, None, None) for x in input._meta_tensor.shape]
        slices[dim] = slice(start, start + length)
        return input[tuple(slices)]

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        if func == torch.narrow:
            if len(args) >= 2:
                kwargs["dim"] = args[1]
            if len(args) >= 3:
                kwargs["start"] = args[2]
            if len(args) >= 4:
                kwargs["length"] = args[3]
            return args[0].narrow(**kwargs)
        new_args = []
        for arg in args:
            if isinstance(arg, DeferredTensor):
                new_args.append(arg.materialize())
            else:
                new_args.append(arg)
        return func(*new_args, **kwargs)

    # implement common tensor operations, except for in-place operations
Reviewer: Why do we need those in the first place?

Author: I wrote this code so that deferred tensors behave just like normal tensors. For example, some users operate directly on weights (see vllm/vllm/model_executor/models/baichuan.py, line 382 in 543aa48). With these functions in place, users' code is minimally affected.

Reviewer: I think this will lead to issues later down the line (we now have an ill-defined subset of supported operations). I feel like explicitly requiring materialization is a better way.

Author: You can refer to da9d486 for an example; that is what I did before. I find that code more intrusive, and all third-party code using vLLM might break.

Reviewer: I think this is too magical and being explicit would be better, but I am not going to block on this.

Commenter: @Yard1, would you consider elaborating? I am not sure I agree that explicitly presenting the tensor is better. I have been working on a very small PR designed to improve availability by eliminating the model-load startup delay (about 1 minute on my hardware, which I can shave down to about 3 seconds). This is pretty straightforward to deliver, presuming the logit operation does not need to be calculated on startup. The gist of my WIP PR is:

Would be interested in hearing your opinions on the approach and a review when ready. There are some constraints:

One problem I struggle with is not technical: it's about making this an optional feature (via a runtime environment variable) or adding it as a flag. Thanks, writing down my ideas in this area was helpful; I will copy this into an issue tracker.
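To make the trade-off discussed in this thread concrete, here is a small hedged sketch (not part of the diff) contrasting the two styles. The stand-in class, names, and sizes are illustrative only; in vLLM the object would be a real `DeferredTensor` from the diff above.

```python
import math

import torch


class _FakeDeferred:
    """Illustrative stand-in with DeferredTensor-like materialize-on-use
    behaviour, so this sketch runs without a checkpoint file."""

    def __init__(self, tensor: torch.Tensor):
        self._tensor = tensor

    def materialize(self) -> torch.Tensor:
        return self._tensor

    def __truediv__(self, other):
        # materialize lazily, the way DeferredTensor.__truediv__ does
        return self.materialize() / other


hidden_size = 4096
weight = _FakeDeferred(torch.ones(8, hidden_size))

# Style kept by this PR: downstream code that does arithmetic directly on
# loaded weights keeps working unchanged; the division materializes lazily.
a = weight / math.sqrt(hidden_size)

# Style suggested in review: explicit materialization, clearer about when
# data is read but requiring changes in downstream model code.
b = weight.materialize() / math.sqrt(hidden_size)

assert torch.equal(a, b)
```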
    def __add__(self, other):
        return self.materialize() + other

    def __radd__(self, other):
        return other + self.materialize()

    def __sub__(self, other):
        return self.materialize() - other

    def __rsub__(self, other):
        return other - self.materialize()

    def __mul__(self, other):
        return self.materialize() * other

    def __rmul__(self, other):
        return other * self.materialize()

    def __truediv__(self, other):
        return self.materialize() / other

    def __rtruediv__(self, other):
        return other / self.materialize()

    def __floordiv__(self, other):
        return self.materialize() // other

    def __rfloordiv__(self, other):
        return other // self.materialize()

    def __mod__(self, other):
        return self.materialize() % other

    def __rmod__(self, other):
        return other % self.materialize()

    def __pow__(self, other):
        return self.materialize()**other

    def __rpow__(self, other):
        return other**self.materialize()

    def __matmul__(self, other):
        return self.materialize() @ other

    def __rmatmul__(self, other):
        return other @ self.materialize()

    def __and__(self, other):
        return self.materialize() & other

    def __rand__(self, other):
        return other & self.materialize()

    def __or__(self, other):
        return self.materialize() | other

    def __ror__(self, other):
        return other | self.materialize()

    def __xor__(self, other):
        return self.materialize() ^ other

    def __rxor__(self, other):
        return other ^ self.materialize()

    def __lshift__(self, other):
        return self.materialize() << other

    def __rlshift__(self, other):
        return other << self.materialize()

    def __rshift__(self, other):
        return self.materialize() >> other

    def __rrshift__(self, other):
        return other >> self.materialize()

    def __eq__(self, other):
        return self.materialize() == other

    def __ne__(self, other):
        return self.materialize() != other

    def __lt__(self, other):
        return self.materialize() < other

    def __le__(self, other):
        return self.materialize() <= other

    def __gt__(self, other):
        return self.materialize() > other

    def __ge__(self, other):
        return self.materialize() >= other

    def __neg__(self):
        return -self.materialize()

    def __pos__(self):
        return +self.materialize()

    def __abs__(self):
        return abs(self.materialize())

    def __invert__(self):
        return ~self.materialize()
Reviewer: Why isn't it possible to subclass or use a meta tensor, for example?

Author: The meta tensor is now used to answer metadata-related queries.
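For context, here is a minimal usage sketch (not taken from the PR) of how a `DeferredTensor` could be constructed and consumed during tensor-parallel weight loading. The cached `lazy_open` helper, the file name, the tensor name, and the dtype/shape literals are all illustrative assumptions; in the PR these would come from the safetensors checkpoint and its header.

```python
from functools import lru_cache

import torch
from safetensors import safe_open


@lru_cache(maxsize=None)
def lazy_open(st_file: str):
    # Open each safetensors file once and cache the handle; no tensor data
    # is read until get_tensor()/get_slice() is called on it.
    return safe_open(st_file, framework="pt")


# dtype is the safetensors dtype string and shape comes from the file
# header, so no tensor data is touched at construction time.
weight = DeferredTensor(lazy_open,
                        "model-00001-of-00002.safetensors",
                        "model.layers.0.mlp.gate_proj.weight",
                        dtype="F16",
                        shape=[11008, 4096])

# Metadata queries are served by the meta tensor without any disk read.
print(weight.shape, weight.dtype)

# Tensor-parallel sharding: torch.narrow is intercepted by
# __torch_function__, so each rank reads only its own slice from disk.
tp_rank, tp_size = 0, 2
shard_size = weight.shape[0] // tp_size
local_shard = torch.narrow(weight, 0, tp_rank * shard_size, shard_size)

# Any other torch function (or arithmetic) materializes the full tensor.
full = weight.materialize()
```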