Skip to content

Commit

Permalink
More corrections
Browse files Browse the repository at this point in the history
  • Loading branch information
metric-space committed Aug 8, 2024
1 parent a381c5a commit 17c0d44
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions mergekit/scripts/ABM/activations_based_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,23 +105,22 @@ def main(

if dtype is not None:
original_w = original_w.to(dtype=dtype)
original_w2 = original_w2.to(dtype=dtype)

w = torch.clone(original_w)

if not merge_matrix and not unmerge_matrix:
if merge_matrix is None and unmerge_matrix is None:
logging.warning(
f"❌ Weight {weight_info.name} for model has no merge or unmerge matrix !!"
)

if merge_matrix is not None:
if weight_info.is_embed:
w = w @ merge_matrix[0].T
w = w @ merge_matrix.T
else:
w = merge_matrix[0] @ w
w = merge_matrix @ w

if unmerge_matrix is not None:
w = w @ unmerge_matrix[0]
w = w @ unmerge_matrix

if torch.allclose(original_w, w):
logging.warning(
Expand Down

0 comments on commit 17c0d44

Please sign in to comment.