Skip to content

Commit

Permalink
feat: auto download segformer weights
Browse files Browse the repository at this point in the history
  • Loading branch information
pnsuau authored and pnsuau committed Apr 8, 2022
1 parent 85171e1 commit 083cc5e
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 2 deletions.
3 changes: 3 additions & 0 deletions models/modules/projected_d/projector.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,9 @@ def create_segformer_model(model_name, config_path, weight_path):
except:
weights = torch.load(weight_path)

if "state_dict" in weights:
weights = weights["state_dict"]

segformer = build_segmentor(cfg.model, train_cfg=None, test_cfg=cfg.get("test_cfg"))
model = segformer.backbone

Expand Down
29 changes: 29 additions & 0 deletions models/modules/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from torch.nn import init
import functools
from torch.optim import lr_scheduler
import wget
import os

##########################################################
# Fonctions used for networks initialisation
Expand Down Expand Up @@ -198,3 +200,30 @@ def normal_init(m, mean, std):
if isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Conv2d):
m.weight.data.normal_(mean, std)
m.bias.data.zero_()


segformer_weights = {
"segformer_mit-b0.pth": "https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b0_512x512_160k_ade20k/segformer_mit-b0_512x512_160k_ade20k_20210726_101530-8ffa8fda.pth",
"segformer_mit-b1.pth": "https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b1_512x512_160k_ade20k/segformer_mit-b1_512x512_160k_ade20k_20210726_112106-d70e859d.pth",
"segformer_mit-b2.pth": "https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b2_512x512_160k_ade20k/segformer_mit-b2_512x512_160k_ade20k_20210726_112103-cbd414ac.pth",
"segformer_mit-b3.pth": "https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b3_512x512_160k_ade20k/segformer_mit-b3_512x512_160k_ade20k_20210726_081410-962b98d2.pth",
"segformer_mit-b4.pth": "https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b4_512x512_160k_ade20k/segformer_mit-b4_512x512_160k_ade20k_20210728_183055-7f509d7d.pth",
"segformer_mit-b5.pth": "https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b5_512x512_160k_ade20k/segformer_mit-b5_512x512_160k_ade20k_20210726_145235-94cedf59.pth",
"segformer_mit-b5_640.pth": "https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b5_640x640_160k_ade20k/segformer_mit-b5_640x640_160k_ade20k_20210801_121243-41d2845b.pth",
}


def download_segformer_weight(path):
for i in range(2, len(path.split("/"))):
temp = path.split("/")[:i]
cur_path = "/".join(temp)
if not os.path.isdir(cur_path):
os.mkdir(cur_path)
model_name = path.split("/")[-1]
if model_name in segformer_weights:
wget.download(segformer_weights[model_name], path)
else:
raise NameError(
"There is no pretrained weight to download for %s, you need to provide a path to segformer weights."
% model_name
)
16 changes: 14 additions & 2 deletions models/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
init_weights,
get_norm_layer,
get_weights,
download_segformer_weight,
)

from .modules.resnet_architecture.resnet_generator import ResnetGenerator
Expand Down Expand Up @@ -345,11 +346,17 @@ def define_D(
)
return net
elif netD == "projected_d": # D in projected feature space
weight_path = os.path.join(jg_dir, D_proj_weight_segformer)
if D_proj_network_type == "segformer" and not os.path.exists(weight_path):
print(
"Downloading pretrained segformer weights for projected D feature extractor."
)
download_segformer_weight(weight_path)
net = ProjectedDiscriminator(
D_proj_network_type,
interp=224 if data_crop_size < 224 else D_proj_interp,
config_path=os.path.join(jg_dir, D_proj_config_segformer),
weight_path=os.path.join(jg_dir, D_proj_weight_segformer),
weight_path=weight_path,
)
return net # no init since custom frozen backbone
else:
Expand Down Expand Up @@ -414,7 +421,12 @@ def define_f(
num_classes=f_s_semantic_nclasses,
final_conv=False,
)
weights = get_weights(os.path.join(jg_dir, f_s_weight_segformer))
weight_path = os.path.join(jg_dir, f_s_weight_segformer)
if not os.path.exists(weight_path):
print("Downloading pretrained segformer weights for f_s.")
download_segformer_weight(weight_path)

weights = get_weights(weight_path)
net.net.load_state_dict(weights, strict=False)
return net

Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ visdom==0.1.8.9
torchviz
imgaug
dominate
wget

0 comments on commit 083cc5e

Please sign in to comment.