Skip to content

Commit

Permalink
simplify set_layer func (i.e. inline the replace_layer helper into overwrite_weights), add sparsezoo-nightly dependency
Browse files: browse the repository at this point in the history
  • Loading branch information
dbogunowicz committed Apr 8, 2024
1 parent 8008cf5 commit d7d9557
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 20 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def _setup_packages() -> List:
)

def _setup_install_requires() -> List:
return ["torch>=1.7.0", "transformers<=4.40", "pydantic>=1.8.2,<2.0.0", "sparsezoo"]
return ["torch>=1.7.0", "transformers<=4.40", "pydantic>=1.8.2,<2.0.0", "sparsezoo-nightly"]

def _setup_extras() -> Dict:
return {"dev": ["black==22.12.0", "isort==5.8.0", "wheel>=0.36.2", "flake8>=3.8.3", "pytest>=6.0.0"]}
Expand Down
25 changes: 6 additions & 19 deletions src/sparsetensors/compressors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,24 +55,6 @@ def decompress(self, model_path: str) -> Generator[Tuple[str, Tensor], None, Non
"""
raise NotImplementedError()

@staticmethod
def replace_layer(param_name: str, data: Tensor, model: Module):
"""
Overwrites a parameterized layer with a new tensor, maintaining the device of
the original parameter
:param param_name: name of parameterized layer to replace
:param data: tensor to insert into model
:param model: pytorch model to insert data into
"""
model_device = operator.attrgetter(param_name)(model).device
new_param = Parameter(data.to(model_device))
# TODO: Two for loops?
for name, param in model.named_parameters():
if name == param_name:
param.data = new_param.data
return

def overwrite_weights(self, model_path: str, model: Module):
"""
Overwrites the weights in model with weights decompressed from model_path
Expand All @@ -82,5 +64,10 @@ def overwrite_weights(self, model_path: str, model: Module):
"""
dense_gen = self.decompress(model_path)
for name, data in tqdm(dense_gen, desc="Decompressing model"):
ModelCompressor.replace_layer(name, data, model)
# loading the decompressed weights into the model
model_device = operator.attrgetter(name)(model).device
data_new = Parameter(data.to(model_device))
data_old = operator.attrgetter(name)(model)
data_old.data = data_new.data

setattr(model, SPARSITY_CONFIG_NAME, self.config)

0 comments on commit d7d9557

Please sign in to comment.