Skip to content

Commit

Permalink
feat: load segformer torchscript weights
Browse files Browse the repository at this point in the history
  • Loading branch information
pnsuau committed Mar 15, 2022
1 parent 2ebcbc6 commit 672b341
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions models/modules/projected_d/projector.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,19 @@ def create_segformer_model(model_name,config_path,weight_path):
import mmcv
cfg = mmcv.Config.fromfile(config_path)
cfg.model.train_cfg = None
try:
weights = torch.jit.load(weight_path).state_dict()
print("Torch script weights are detected and loaded in %s"%weight_path)
except:
weights = torch.load(weight_path)

segformer = build_segmentor(cfg.model, train_cfg=None, test_cfg=cfg.get('test_cfg'))
mmcv.runner.load_checkpoint(
segformer, weight_path,
map_location='cpu',
)

model = segformer.backbone


weights = { key.replace("backbone.",""):value for (key,value) in weights.items() if "backbone." in key}

model.load_state_dict(weights, strict=True)

return model

projector_models = {
Expand Down

0 comments on commit 672b341

Please sign in to comment.