diff --git a/setup.py b/setup.py
index 9afaa4b0..eb6d90cc 100644
--- a/setup.py
+++ b/setup.py
@@ -22,7 +22,7 @@ def _setup_packages() -> List:
     )
 
 def _setup_install_requires() -> List:
-    return ["torch>=1.7.0", "transformers<=4.40", "pydantic>=1.8.2,<2.0.0", "sparsezoo"]
+    return ["torch>=1.7.0", "transformers<=4.40", "pydantic>=1.8.2,<2.0.0", "sparsezoo-nightly"]
 
 def _setup_extras() -> Dict:
     return {"dev": ["black==22.12.0", "isort==5.8.0", "wheel>=0.36.2", "flake8>=3.8.3", "pytest>=6.0.0"]}
diff --git a/src/sparsetensors/compressors/base.py b/src/sparsetensors/compressors/base.py
index 7b013827..39f725e8 100644
--- a/src/sparsetensors/compressors/base.py
+++ b/src/sparsetensors/compressors/base.py
@@ -55,24 +55,6 @@ def decompress(self, model_path: str) -> Generator[Tuple[str, Tensor], None, Non
         """
         raise NotImplementedError()
 
-    @staticmethod
-    def replace_layer(param_name: str, data: Tensor, model: Module):
-        """
-        Overwrites a parameterized layer with a new tensor, maintaining the device of
-        the original parameter
-
-        :param param_name: name of parameterized layer to replace
-        :param data: tensor to insert into model
-        :param model: pytorch model to insert data into
-        """
-        model_device = operator.attrgetter(param_name)(model).device
-        new_param = Parameter(data.to(model_device))
-        # TODO: Two for loops?
-        for name, param in model.named_parameters():
-            if name == param_name:
-                param.data = new_param.data
-                return
-
     def overwrite_weights(self, model_path: str, model: Module):
         """
         Overwrites the weights in model with weights decompressed from model_path
@@ -82,5 +64,9 @@ def overwrite_weights(self, model_path: str, model: Module):
         """
         dense_gen = self.decompress(model_path)
         for name, data in tqdm(dense_gen, desc="Decompressing model"):
-            ModelCompressor.replace_layer(name, data, model)
+            # resolve the target parameter once by its dotted name, then load the
+            # decompressed tensor into it on the device the parameter already uses
+            param = operator.attrgetter(name)(model)
+            param.data = Parameter(data.to(param.device)).data
 
         setattr(model, SPARSITY_CONFIG_NAME, self.config)