Skip to content

Commit

Permalink
feat: allow ViT custom resolution at D projector init
Browse files Browse the repository at this point in the history
  • Loading branch information
beniz committed Jun 8, 2022
1 parent 6fc1dd4 commit 82e6e83
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions models/modules/projected_d/projector.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,22 +143,25 @@ def calc_channels(pretrained, inp_res=224):
return channels, feats


def create_timm_model(model_name, config_path, weight_path):
def create_timm_model(model_name, config_path, weight_path, img_size):
import timm

model = timm.create_model(model_name, pretrained=True)
if "vit" in model_name:
model = timm.create_model(model_name, img_size=img_size, pretrained=True)
else:
model = timm.create_model(model_name, pretrained=True)
return model


def create_clip_model(model_name, config_path, weight_path):
def create_clip_model(model_name, config_path, weight_path, img_size):
import clip

model = clip.load(model_name)

return model[0].visual.float().cpu()


def create_segformer_model(model_name, config_path, weight_path):
def create_segformer_model(model_name, config_path, weight_path, img_size):
from mmseg.models import build_segmentor
import mmcv

Expand Down Expand Up @@ -229,7 +232,7 @@ def _make_projector(
### Build pretrained feature network
projector_gen = projector_models[projector_model]
model = projector_gen["create_model_function"](
projector_gen["model_name"], config_path, weight_path
projector_gen["model_name"], config_path, weight_path, interp
)

pretrained = projector_gen["make_function"](model)
Expand Down

0 comments on commit 82e6e83

Please sign in to comment.