Skip to content

Commit

Permalink
Update _solver.py
Browse files Browse the repository at this point in the history
Fix class mapping. #2
  • Loading branch information
Peterande authored Oct 18, 2024
1 parent 142cad3 commit 64e6680
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions src/solver/_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,11 @@ def load_tuning_state(self, path: str):
pretrain_state_dict = state['model']

# Adjust head parameters between datasets
adjusted_state_dict = self._adjust_head_parameters(module.state_dict(), pretrain_state_dict)
stat, infos = self._matched_state(module.state_dict(), adjusted_state_dict)
try:
adjusted_state_dict = self._adjust_head_parameters(module.state_dict(), pretrain_state_dict)
stat, infos = self._matched_state(module.state_dict(), adjusted_state_dict)
except:
stat, infos = self._matched_state(module.state_dict(), pretrain_state_dict)

module.load_state_dict(stat, strict=False)
print(f'Load model.state_dict, {infos}')
Expand Down Expand Up @@ -256,4 +259,4 @@ def fit(self):
raise NotImplementedError('')

def val(self):
raise NotImplementedError('')
raise NotImplementedError('')

0 comments on commit 64e6680

Please sign in to comment.