-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
37 lines (31 loc) · 1.29 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
def match_state_dict(state_dict_a, state_dict_b):
    """Split ``state_dict_b`` into entries compatible with ``state_dict_a`` and the rest.

    An entry ``(key, value)`` of ``state_dict_b`` is *matched* when both hold:

    - ``key`` is also a key of ``state_dict_a``, and
    - ``value.shape == state_dict_a[key].shape``.

    Args:
        state_dict_a: Reference mapping. Only membership tests and the
            ``.shape`` of its values are used.
        state_dict_b: Mapping to split. A value's ``.shape`` is read only
            when its key is also present in ``state_dict_a``.

    Returns:
        A ``(matched_state_dict, unmatched_state_dict)`` tuple of new dicts.
        ``matched_state_dict`` holds the entries of ``state_dict_b`` that pass
        both checks. ``unmatched_state_dict`` holds every other entry, which
        covers both keys absent from ``state_dict_a`` and keys present there
        with a different shape. The two dicts are disjoint, their union is
        ``state_dict_b``, and each keeps ``state_dict_b``'s iteration order.
        Values are the original objects, not copies.
    """
    matched, unmatched = {}, {}
    for name, value in state_dict_b.items():
        # Short-circuit: the shape is only compared when the key exists in A.
        compatible = name in state_dict_a and value.shape == state_dict_a[name].shape
        target = matched if compatible else unmatched
        target[name] = value
    return matched, unmatched
def compute_num_basis(nx, nf, group_strategy, compression_ratio):
    """Return the number of basis elements that fit in the post-compression budget.

    The parameter budget is
    ``nx * nf * group_strategy * (1 - compression_ratio / 100)``, and each
    basis element costs ``nx + nf * group_strategy``. The result is the
    floor of budget / cost, i.e. the largest count whose total cost does not
    exceed the budget. (Presumably this sizes a shared-basis / low-rank
    factorization of an ``nx x (nf * group_strategy)`` weight. Confirm
    against callers.)

    Args:
        nx: First size factor of the original tensor.
        nf: Second size factor, multiplied by ``group_strategy``.
        group_strategy: Multiplier applied to ``nf``.
        compression_ratio: Percentage of the original size to remove. For
            example, ``75`` keeps 25% of the budget. The value is not
            validated or clamped to ``[0, 100]``.

    Returns:
        The floored basis count as an ``int``.

    Raises:
        ZeroDivisionError: If ``nx + nf * group_strategy == 0``.
    """
    total = nx * nf * group_strategy
    per_basis = nx + nf * group_strategy
    # Avoid forming the fractional keep-ratio ``1 - compression_ratio / 100``:
    # it is inexact in binary (e.g. 1 - 90/100 == 0.09999999999999998), and
    # floor division then drops the result one below an exact integer.
    # Folding the 100 into the divisor keeps integer inputs fully exact.
    num_basis = (total * (100 - compression_ratio)) // (100 * per_basis)
    return int(num_basis)