-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathutils.py
79 lines (57 loc) · 2.49 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import copy
import efficientnet_repo.Yet_Another_EfficientDet_Pytorch.efficientnet.utils_extra as yet_utils
import efficientnet_repo.Yet_Another_EfficientDet_Pytorch.efficientnet.model as yet_model
import efficientnet_repo.EfficientNet_PyTorch.efficientnet_pytorch.utils as ept_utils
import efficientnet_repo.EfficientNet_PyTorch.efficientnet_pytorch.model as ept_model
# Keep a reference to the stock ``EfficientNet.from_name`` as it exists at
# import time, so ``_from_name_hook`` can still delegate to it after the class
# attribute has been replaced by ``EfficientNetModule.EfficientNet(need_hook=True)``.
_origin_from_name = ept_model.EfficientNet.from_name
def _from_name_hook(model_name, cfg=None, override_params=dict(num_classes=1000)):
    """Call the original ``from_name`` with overrides given as one mapping.

    Args:
        model_name: Forwarded positionally to the original ``from_name``.
        cfg: Forwarded as the ``cfg`` keyword argument.
        override_params: Mapping that is expanded into keyword arguments.
            Defaults to ``{"num_classes": 1000}``.

    Returns:
        Whatever the original ``from_name`` returns.
    """
    # Expand a private deep copy, so neither the shared default dict nor a
    # caller-owned dict (including nested values) is handed downstream.
    expanded_kwargs = copy.deepcopy(override_params)
    return _origin_from_name(model_name, cfg=cfg, **expanded_kwargs)
class EfficientNetModule(object):
    """Pick EfficientNet building blocks from one of two vendored repos.

    ``repo`` selects the backing implementation:

    * ``"yet"`` -- Yet_Another_EfficientDet_Pytorch (``yet_model`` / ``yet_utils``)
    * ``"ept"`` -- EfficientNet_PyTorch (``ept_model`` / ``ept_utils``)

    Every accessor returns the class object itself (not an instance) and
    raises ``AssertionError`` when ``repo`` is not one of the keys above.
    The error is raised explicitly rather than through an ``assert``
    statement, so the check is still enforced under ``python -O``.
    """

    def __init__(self, repo):
        """Store the repository key that the accessors resolve against."""
        self.repo = repo

    def _unsupported(self, what):
        """Build the error raised when ``self.repo`` cannot provide ``what``."""
        return AssertionError(
            "unsupported repo %r for %s; expected 'yet' or 'ept'"
            % (self.repo, what)
        )

    def MBConvBlock(self):
        """Return the ``MBConvBlock`` class of the selected repository.

        Raises:
            AssertionError: If ``self.repo`` is neither ``"yet"`` nor ``"ept"``.
        """
        if self.repo == "yet":  # Yet another EfficientDet PyTorch
            return yet_model.MBConvBlock
        if self.repo == "ept":  # EfficientNet PyTorch
            return ept_model.MBConvBlock
        raise self._unsupported("MBConvBlock")

    def Conv2dStaticSamePadding(self):
        """Return the ``Conv2dStaticSamePadding`` class of the selected repository.

        Raises:
            AssertionError: If ``self.repo`` is neither ``"yet"`` nor ``"ept"``.
        """
        if self.repo == "yet":
            return yet_utils.Conv2dStaticSamePadding
        if self.repo == "ept":
            return ept_utils.Conv2dStaticSamePadding
        raise self._unsupported("Conv2dStaticSamePadding")

    def EfficientNet(self, need_hook=False):
        """Return the ``EfficientNet`` class of the selected repository.

        Args:
            need_hook: Only honoured for ``repo == "ept"``. When truthy, the
                class attribute ``ept_model.EfficientNet.from_name`` is
                replaced by the module-level ``_from_name_hook`` before the
                class is returned.

        Raises:
            AssertionError: If ``self.repo`` is neither ``"yet"`` nor ``"ept"``.
        """
        if self.repo == "yet":
            return yet_model.EfficientNet
        if self.repo == "ept":
            if need_hook:
                # NOTE: this rebinds an attribute on the class itself, so the
                # replacement is visible to every user of the class and stays
                # in place after this call returns.
                ept_model.EfficientNet.from_name = _from_name_hook
            return ept_model.EfficientNet
        raise self._unsupported("EfficientNet")

    def from_pruned(self, name, path, override_params):
        """Build a model via ``from_name_pruned`` of the selected class.

        Args:
            name: Forwarded positionally to ``from_name_pruned``.
            path: Forwarded as ``state_dict_path``.
            override_params: Forwarded unchanged as ``override_params``.

        Returns:
            Whatever ``from_name_pruned`` returns.

        Raises:
            AssertionError: If ``self.repo`` is neither ``"yet"`` nor ``"ept"``.
        """
        # NOTE(review): ``from_name_pruned`` is not defined in this file; it is
        # presumably supplied by the vendored EfficientNet classes -- confirm.
        net_cls = self.EfficientNet()
        return net_cls.from_name_pruned(
            name, state_dict_path=path, override_params=override_params
        )
def test_enet_module():
    """Check that each repo key resolves to the classes of its own modules."""
    # (repo key, module holding MBConvBlock, module holding Conv2dStaticSamePadding)
    expectations = (
        ("yet", yet_model, yet_utils),
        ("ept", ept_model, ept_utils),
    )
    for repo_key, model_mod, utils_mod in expectations:
        selector = EfficientNetModule(repo_key)
        assert selector.MBConvBlock() == model_mod.MBConvBlock
        assert selector.Conv2dStaticSamePadding() == utils_mod.Conv2dStaticSamePadding
def main():
    """Script entry point: run the module's self-check."""
    test_enet_module()


# Only run the self-check when executed as a script, never on import.
if __name__ == "__main__":
    main()