_pretrained.py
import copy
from collections import deque, defaultdict
from dataclasses import dataclass, field, replace, asdict
from typing import Any, Deque, Dict, Tuple, Optional, Union

__all__ = ['PretrainedCfg', 'filter_pretrained_cfg', 'DefaultCfg']


@dataclass
class PretrainedCfg:
"""
"""
    # weight source locations
    url: Optional[Union[str, Tuple[str, str]]] = None  # remote URL
    file: Optional[str] = None  # local / shared filesystem path
    state_dict: Optional[Dict[str, Any]] = None  # in-memory state dict
    hf_hub_id: Optional[str] = None  # Hugging Face Hub model id ('organization/model')
    hf_hub_filename: Optional[str] = None  # Hugging Face Hub filename (overrides default)

    source: Optional[str] = None  # source of cfg / weight location used (url, file, hf-hub)
    architecture: Optional[str] = None  # architecture variant can be set when not implicit
    tag: Optional[str] = None  # pretrained tag of source
    custom_load: bool = False  # use custom model specific model.load_pretrained() (ie for npz files)

    # input / data config
    input_size: Tuple[int, int, int] = (3, 224, 224)
    test_input_size: Optional[Tuple[int, int, int]] = None
    min_input_size: Optional[Tuple[int, int, int]] = None
    fixed_input_size: bool = False
    interpolation: str = 'bicubic'
    crop_pct: float = 0.875
    test_crop_pct: Optional[float] = None
    crop_mode: str = 'center'
    mean: Tuple[float, ...] = (0.485, 0.456, 0.406)
    std: Tuple[float, ...] = (0.229, 0.224, 0.225)

    # head / classifier config and meta-data
    num_classes: int = 1000
    label_offset: Optional[int] = None
    label_names: Optional[Tuple[str]] = None
    label_descriptions: Optional[Dict[str, str]] = None

    # model attributes that vary with above or required for pretrained adaptation
    pool_size: Optional[Tuple[int, ...]] = None
    test_pool_size: Optional[Tuple[int, ...]] = None
    first_conv: Optional[str] = None
    classifier: Optional[str] = None

    license: Optional[str] = None
    description: Optional[str] = None
    origin_url: Optional[str] = None
    paper_name: Optional[str] = None
    paper_ids: Optional[Union[str, Tuple[str]]] = None
    notes: Optional[Tuple[str]] = None

    @property
    def has_weights(self):
        return self.url or self.file or self.hf_hub_id

    def to_dict(self, remove_source=False, remove_null=True):
        return filter_pretrained_cfg(
            asdict(self),
            remove_source=remove_source,
            remove_null=remove_null,
        )
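
# Illustrative usage sketch (the hub id below is an assumed placeholder, not a real
# model): a cfg for an ImageNet-1k checkpoint hosted on the Hugging Face Hub might
# look roughly like
#
#   cfg = PretrainedCfg(hf_hub_id='org/model.a1_in1k', num_classes=1000)
#   assert cfg.has_weights                  # truthy if any of url / file / hf_hub_id is set
#   cfg.to_dict(remove_source=True)         # plain dict without the weight-source keys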


def filter_pretrained_cfg(cfg, remove_source=False, remove_null=True):
    filtered_cfg = {}
    keep_null = {'pool_size', 'first_conv', 'classifier'}  # always keep these keys, even if none
    for k, v in cfg.items():
        if remove_source and k in {'url', 'file', 'hf_hub_id', 'hf_hub_filename', 'source'}:
            continue
        if remove_null and v is None and k not in keep_null:
            continue
        filtered_cfg[k] = v
    return filtered_cfg
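
# Behaviour sketch (illustrative input values): with remove_null=True the keys in
# keep_null survive as explicit None entries, e.g.
#
#   filter_pretrained_cfg({'url': None, 'classifier': None, 'crop_pct': 0.875})
#   # -> {'classifier': None, 'crop_pct': 0.875}
#
# while remove_source=True additionally drops the url / file / hf_hub_* / source keys.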


@dataclass
class DefaultCfg:
    tags: Deque[str] = field(default_factory=deque)  # priority queue of tags (first is default)
    cfgs: Dict[str, PretrainedCfg] = field(default_factory=dict)  # pretrained cfgs by tag
    is_pretrained: bool = False  # at least one of the configs has a pretrained source set

    @property
    def default(self):
        return self.cfgs[self.tags[0]]

    @property
    def default_with_tag(self):
        tag = self.tags[0]
        return tag, self.cfgs[tag]
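

# Minimal runnable sketch (all tag / hub-id strings are assumed placeholders):
# DefaultCfg groups several PretrainedCfg entries under tags, and the first tag
# in the deque acts as the default selection.
if __name__ == '__main__':
    _default = DefaultCfg(
        tags=deque(['a1_in1k', 'untrained']),
        cfgs={
            'a1_in1k': PretrainedCfg(hf_hub_id='org/model.a1_in1k', num_classes=1000),
            'untrained': PretrainedCfg(),
        },
        is_pretrained=True,
    )
    tag, cfg = _default.default_with_tag
    print(tag, bool(cfg.has_weights))                # -> a1_in1k True
    print(sorted(cfg.to_dict(remove_source=True)))   # weight-source keys filtered out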