Enable HF pretrained backbones #31145
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Force-pushed from 450e7c1 to 2d4e302
@@ -333,16 +333,6 @@ config = MaskFormerConfig(backbone="microsoft/resnet50", use_pretrained_backbone
model = MaskFormerForInstanceSegmentation(config) # head
```

You could also load the backbone config separately and then pass it to the model config.
Removed as it's repeated and included in the section about loading pretrained backbones, but the example loads randomly initialized weights
Force-pushed from fe0a50a to 3763618
@@ -50,7 +50,7 @@ def __init__(self, config, **kwargs):
        if config.backbone is None:
            raise ValueError("backbone is not set in the config. Please set it to a timm model name.")

-       if config.backbone not in timm.list_models():
+       if config.backbone.split(".")[0] not in timm.list_models():
This change is because certain timm checkpoints have the base model name plus a pretrained-tag suffix, e.g. `vit_large_patch14_dinov2.lvd142m`.
Ok, maybe worth adding this comment to the code, or doing it outside the `if` statement with a clear variable name?
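To illustrate the check above, here is a minimal sketch (the helper name is mine, not part of transformers): timm checkpoint names can carry a pretrained-tag suffix after a dot, while `timm.list_models()` returns bare architecture names, so the comparison has to use only the architecture part.

```python
def base_architecture(checkpoint_name: str) -> str:
    """Strip the pretrained-tag suffix, keeping only the timm architecture name.

    e.g. "vit_large_patch14_dinov2.lvd142m" -> "vit_large_patch14_dinov2"
    """
    return checkpoint_name.split(".")[0]

print(base_architecture("vit_large_patch14_dinov2.lvd142m"))
```

Pulling the split out into a named helper like this is what the reviewer suggests, so the membership test reads `if base_architecture(config.backbone) not in timm.list_models():`.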
-config = MaskFormerConfig(backbone="microsoft/resnet50", use_pretrained_backbone=True) # backbone and neck config
+config = MaskFormerConfig(backbone="microsoft/resnet-50", use_pretrained_backbone=True) # backbone and neck config
Changed to the correct checkpoint
-self.backbone = AutoBackbone.from_config(
-    config.backbone_config, attn_implementation=config._attn_implementation
-)
+self.backbone = load_backbone(config)
All backbones should be loaded through `load_backbone`. Being able to propagate `attn_implementation` should be done in a follow-up (and it wasn't being used for Depth Anything).
Force-pushed from 969fb8c to 2c84a2b
Great, thanks for working on this! Just a few comments, nothing critical
backbone_model_type = None
if config.backbone is not None:
    backbone_model_type = config.backbone
elif config.backbone_config is not None:
    backbone_model_type = config.backbone_config.model_type
Is there a chance that `backbone` and `backbone_config` are both `None` by misconfiguration? Should we throw an error here, or check in `verify_backbone_config_arguments`?

If `backbone_model_type` remains `None`, then `if "resnet" in backbone_model_type:` will raise `TypeError: argument of type 'NoneType' is not iterable`.
> Is there a chance that backbone and backbone_config are both None by misconfiguration?

It shouldn't happen from a fresh config, as `verify_backbone_config_arguments` will check that. It's possible something like this could happen, as it's not uncommon to modify configs post-creation:

from transformers import ConditionalDetrConfig, ConditionalDetrModel

config = ConditionalDetrConfig()
config.backbone = None
config.backbone_config = None
model = ConditionalDetrModel(config)

> If backbone_model_type will remain None, then if "resnet" in backbone_model_type: will raise an error:

Good point. I'll add an exception if neither is set, here and for the other similar bits of logic.
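The exception discussed above could look something like the following sketch (names and shape are mine, not the transformers implementation): resolve the backbone model type, and raise instead of returning `None` so a later substring check can never hit the `TypeError`.

```python
from types import SimpleNamespace

def get_backbone_model_type(config):
    """Return the backbone identifier, raising if neither field is set."""
    if config.backbone is not None:
        return config.backbone
    if config.backbone_config is not None:
        return config.backbone_config.model_type
    raise ValueError("Either `backbone` or `backbone_config` must be set in the config.")

# Stand-in config object, mimicking post-creation edits to a real config.
config = SimpleNamespace(backbone=None, backbone_config=SimpleNamespace(model_type="resnet"))
print(get_backbone_model_type(config))
```

With this guard, `if "resnet" in get_backbone_model_type(config):` either works or fails early with a clear message rather than a `NoneType` error.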
@@ -182,7 +182,7 @@ def __init__(

        use_autobackbone = False
        if self.is_hybrid:
-           if backbone_config is None and backbone is None:
+           if backbone_config is None:
nit: if we use `if`-`if` instead of `if`-`elif`, we can remove the logging and config initialization from the first `if`; it will be handled by the second `if`:
if backbone_config is None:
- logger.info("Initializing the config with a `BiT` backbone.")
backbone_config = {
"global_padding": "same",
"layer_type": "bottleneck",
"depths": [3, 4, 9],
"out_features": ["stage1", "stage2", "stage3"],
"embedding_dynamic_padding": True,
}
- backbone_config = BitConfig(**backbone_config)
if isinstance(backbone_config, dict):
logger.info("Initializing the config with a `BiT` backbone.")
backbone_config = BitConfig(**backbone_config)
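The suggested `if`-`if` pattern can be sketched as follows, with a stub class standing in for `transformers.BitConfig`: the first `if` only supplies a dict default, and the second converts any dict (default or user-supplied) into a config object, so conversion and logging live in one place.

```python
class StubBitConfig:
    """Stand-in for transformers.BitConfig, for illustration only."""
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

def resolve_backbone_config(backbone_config=None):
    # First `if`: supply the default as a plain dict only.
    if backbone_config is None:
        backbone_config = {"layer_type": "bottleneck", "depths": [3, 4, 9]}
    # Second `if`: convert any dict into a config object; this handles
    # both the default above and a user-supplied dict.
    if isinstance(backbone_config, dict):
        backbone_config = StubBitConfig(**backbone_config)
    return backbone_config
```

Already-constructed config objects pass through both `if`s untouched, which is why the `elif` isn't needed.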
Use `use_timm_backbone=True` and `use_pretrained_backbone=True` to load pretrained timm weights for the backbone.

```python
from transformers import MaskFormerConfig, MaskFormerForInstanceSegmentation

config = MaskFormerConfig(backbone="resnet50", use_pretrained_backbone=True, use_timm_backbone=True) # backbone and neck config
model = MaskFormerForInstanceSegmentation(config) # head
```

Set `use_timm_backbone=True` and `use_pretrained_backbone=False` to load a randomly initialized timm backbone.

```python
from transformers import MaskFormerConfig, MaskFormerForInstanceSegmentation

config = MaskFormerConfig(backbone="resnet50", use_pretrained_backbone=False, use_timm_backbone=True) # backbone and neck config
model = MaskFormerForInstanceSegmentation(config) # head
```

You could also load the backbone config and use it to create a `TimmBackbone` or pass it to the model config. Timm backbones will load pretrained weights by default. Set `use_pretrained_backbone=False` to load randomly initialized weights.
❤️
Accidentally left comments as review comments and I'm too lazy to store / delete / re-add
self.backbone = backbone
self.use_pretrained_backbone = use_pretrained_backbone
self.use_timm_backbone = use_timm_backbone
self.backbone_kwargs = backbone_kwargs
self.num_hidden_layers = None if use_autobackbone else num_hidden_layers
We don't default to setting these to `None` if `use_autobackbone` is `True` as:

- Even if `is_hybrid` is `False`, some of these values are needed, e.g. `image_size` and `patch_size` are needed in the patch embeddings, which are used in DPTViTEmbeddings initialized here.
- It can lead to surprising behaviour, e.g. I pass in `DPTConfig(is_hybrid=False, num_hidden_layers=5)` and see that `num_hidden_layers` is then set to `None`.
@@ -208,9 +208,8 @@ def __init__(
        if readout_type != "project":
            raise ValueError("Readout type must be 'project' when using `DPT-hybrid` mode.")

-       elif backbone_config is not None:
+       elif backbone is not None or backbone_config is not None:
To specify which backbone to load from a config, the user can do one of two things:

- Specify a checkpoint, e.g. `backbone="microsoft/resnet-10"`
- Specify a config, e.g. `backbone_config=BitConfig()`

We need to be able to support both to enable `load_backbone`, i.e. loading timm or HF pretrained and randomly initialized architectures.
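The two entry points described above can be sketched as a small dispatch (function and return shapes are mine for illustration, not the transformers API): the loader looks at whichever of the two fields is set and decides between loading named pretrained weights and building a randomly initialized architecture.

```python
def resolve_backbone_source(backbone=None, backbone_config=None):
    """Decide how a backbone should be constructed, mirroring the two options above."""
    if backbone is not None:
        # e.g. backbone="microsoft/resnet-50": load pretrained weights by checkpoint name
        return ("checkpoint", backbone)
    if backbone_config is not None:
        # e.g. a nested config: build a randomly initialized architecture of that type
        return ("config", backbone_config["model_type"])
    raise ValueError("Either `backbone` or `backbone_config` must be set.")
```

A real `load_backbone` additionally has to route between timm and HF model classes, but the branch structure is the point here.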
Force-pushed from d9c7d77 to d40fcd2
* Enable loading HF or timm backbone checkpoints
* Fix up
* Fix test - pass in proper out_indices
* Update docs
* Fix tvp tests
* Fix doc examples
* Fix doc examples
* Try to resolve DPT backbone param init
* Don't conditionally set to None
* Add condition based on whether backbone is defined
* Address review comments
What does this PR do?
Enables loading HF pretrained model weights for backbones.
`verify_backbone_config_arguments`