From 25e9cd54fcf05256a8e5334849aa9adee5feef97 Mon Sep 17 00:00:00 2001 From: Connor Holmes Date: Wed, 15 Nov 2023 19:40:20 +0000 Subject: [PATCH] Ensure all parameters are aligned --- .../v2/model_implementations/flat_model_helpers.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/deepspeed/inference/v2/model_implementations/flat_model_helpers.py b/deepspeed/inference/v2/model_implementations/flat_model_helpers.py index dbec911230f5..f9da7ac5d23e 100644 --- a/deepspeed/inference/v2/model_implementations/flat_model_helpers.py +++ b/deepspeed/inference/v2/model_implementations/flat_model_helpers.py @@ -16,6 +16,13 @@ from ..inference_utils import elem_size +def pad_to_aligned_offset(offset: int, alignment: int = 256) -> int: + """ + Pad the provided offset to a well-aligned value. + """ + return ((offset + alignment - 1) // alignment) * alignment + + class TensorMetadata(DeepSpeedConfigModel): """ A class to represent a tensor specification. @@ -149,7 +156,7 @@ def process_layer(layer_container: LayerContainer, l_name: str, cur_offset: int) strides=param.stride(), offset=cur_offset) - cur_offset += elem_size(param.dtype) * param.numel() + cur_offset += pad_to_aligned_offset(elem_size(param.dtype) * param.numel()) for t_name, tensor in param.aux_attrs.items(): param_metadata.aux_params[t_name] = TensorMetadata(dtype=str(tensor.dtype), @@ -157,7 +164,7 @@ def process_layer(layer_container: LayerContainer, l_name: str, cur_offset: int) strides=tensor.stride(), offset=cur_offset) - cur_offset += elem_size(param.dtype) * param.numel() + cur_offset += pad_to_aligned_offset(elem_size(param.dtype) * param.numel()) layer_metadata.params[p_name] = param_metadata