diff --git a/deepspeed/inference/v2/model_implementations/flat_model_helpers.py b/deepspeed/inference/v2/model_implementations/flat_model_helpers.py index dbec911230f5..f9da7ac5d23e 100644 --- a/deepspeed/inference/v2/model_implementations/flat_model_helpers.py +++ b/deepspeed/inference/v2/model_implementations/flat_model_helpers.py @@ -16,6 +16,13 @@ from ..inference_utils import elem_size +def pad_to_aligned_offset(offset: int, alignment: int = 256) -> int: + """ + Pad the provided offset to a well-aligned value. + """ + return ((offset + alignment - 1) // alignment) * alignment + + class TensorMetadata(DeepSpeedConfigModel): """ A class to represent a tensor specification. @@ -149,7 +156,7 @@ def process_layer(layer_container: LayerContainer, l_name: str, cur_offset: int) strides=param.stride(), offset=cur_offset) - cur_offset += elem_size(param.dtype) * param.numel() + cur_offset += pad_to_aligned_offset(elem_size(param.dtype) * param.numel()) for t_name, tensor in param.aux_attrs.items(): param_metadata.aux_params[t_name] = TensorMetadata(dtype=str(tensor.dtype), @@ -157,7 +164,7 @@ def process_layer(layer_container: LayerContainer, l_name: str, cur_offset: int) strides=tensor.stride(), offset=cur_offset) - cur_offset += elem_size(param.dtype) * param.numel() + cur_offset += pad_to_aligned_offset(elem_size(param.dtype) * param.numel()) layer_metadata.params[p_name] = param_metadata