[Feature] Add connector components and FitNet #207
Conversation
Codecov Report
@@ Coverage Diff @@
## dev-1.x #207 +/- ##
==========================================
- Coverage 0.68% 0.67% -0.02%
==========================================
Files 115 119 +4
Lines 4216 4321 +105
Branches 659 675 +16
==========================================
Hits 29 29
- Misses 4182 4287 +105
Partials 5 5
class BaseConnector(nn.Module, metaclass=ABCMeta):
    """Base class of connectors.

    Connector is mainly used for distill, it usually converts the channel
distill -> distillation
import torch.nn as nn


class BaseConnector(nn.Module, metaclass=ABCMeta):
It's better to inherit from BaseModule? Then, self.init_parameters() could be rewritten by init_weights() and init_cfg.
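The suggestion above can be illustrated with a framework-free sketch of the connector template. BaseModule itself lives in mmcv/mmengine; the stub below only mirrors the shape of the API, and IdentityConnector is a hypothetical subclass used purely for demonstration.

```python
from abc import ABCMeta, abstractmethod

# Framework-free sketch of the connector base class under discussion.
# In the real PR it inherits nn.Module; the reviewer suggests mmcv's
# BaseModule, whose init_cfg/init_weights would replace init_parameters().
class BaseConnector(metaclass=ABCMeta):

    def __init__(self, init_cfg=None):
        # BaseModule would consume init_cfg; here it is just stored.
        self.init_cfg = init_cfg

    def forward(self, feature):
        # Template method: concrete connectors implement forward_train.
        return self.forward_train(feature)

    @abstractmethod
    def forward_train(self, feature):
        ...


# Hypothetical subclass used only to exercise the template.
class IdentityConnector(BaseConnector):

    def forward_train(self, feature):
        return feature


print(IdentityConnector().forward([1, 2, 3]))  # prints [1, 2, 3]
```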
@MODELS.register_module()
class BNConnector(BaseConnector):
-> ConvBNConnector
        out_channel: int,
    ) -> None:
        super().__init__()
        self.conv = nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1)
Use build_conv_layer to build a conv layer.
        out_channel: int,
    ) -> None:
        super().__init__()
        self.conv = nn.Conv2d(
Use build_conv_layer to build a conv layer and build_norm_layer to build a norm layer.
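The build_conv_layer / build_norm_layer helpers the reviewer refers to are mmcv's cfg-driven builders: a dict with a 'type' key selects a registered layer class. Below is a minimal stand-alone imitation of that pattern; the registry and stub class are illustrative and not mmcv's real implementation.

```python
# Minimal imitation of mmcv's cfg-driven builder pattern: a registry maps the
# 'type' key of a config dict to a layer class. Stub classes only; mmcv's real
# build_conv_layer/build_norm_layer construct actual nn.Module layers.
CONV_LAYERS = {}


def register_conv(name):
    def decorator(cls):
        CONV_LAYERS[name] = cls
        return cls
    return decorator


@register_conv('Conv2d')
class Conv2dStub:
    def __init__(self, in_channels, out_channels, kernel_size, stride=1):
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride


def build_conv_layer(cfg, *args, **kwargs):
    cfg = dict(cfg)  # copy so the caller's config dict is not mutated
    layer_cls = CONV_LAYERS[cfg.pop('type')]
    return layer_cls(*args, **kwargs, **cfg)


conv = build_conv_layer(dict(type='Conv2d'), 16, 32, kernel_size=1)
print(conv.out_channels)  # prints 32
```

The point of the pattern is that swapping the conv or norm implementation becomes a config change (e.g. conv_cfg / norm_cfg) instead of a code change.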
@MODELS.register_module()
class ReLUConnector(BaseConnector):
-> ConvBNReLUConnector?
        out_channel: int,
    ) -> None:
        super().__init__()
        self.conv = nn.Conv2d(in_channel, out_channel, kernel_size=1)
The same
@@ -38,6 +38,14 @@ class ConfigurableDistiller(BaseDistiller):
        distill_deliveries (dict, optional): Config for multiple deliveries. A
            distill algorithm may have more than one delivery. Defaults to
            None.
        student_connectors (dict, optional): Config for multiple connectors. A
Explain the key mapping relations between connectors, distill_losses, and loss_forward_mappings in the Note below.
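A hypothetical config fragment sketching the keying convention the reviewer asks to document: the connector dicts are keyed by loss name, so each connector is tied to one entry of distill_losses, whose inputs are in turn described by loss_forward_mappings. All the concrete names below (loss_s4, bb_s4, channel counts) are illustrative.

```python
# Illustrative distiller config: keys of student_connectors must match keys of
# distill_losses, and loss_forward_mappings says which recorded features feed
# each loss. Every name here is made up for demonstration.
distill_losses = dict(
    loss_s4=dict(type='L2Loss', loss_weight=10))

student_connectors = dict(
    loss_s4=dict(type='SingleConvConnector', in_channel=256, out_channel=512))

loss_forward_mappings = dict(
    loss_s4=dict(
        s_feature=dict(from_student=True, recorder='bb_s4', record_idx=1),
        t_feature=dict(from_student=False, recorder='bb_s4', record_idx=2)))

# The constraints a build_connectors-style helper would check:
assert set(student_connectors) <= set(distill_losses)
assert set(loss_forward_mappings) == set(distill_losses)
```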
        distill_connecotrs = nn.ModuleDict()
        if connectors:
            for loss_name, connector_cfg in connectors.items():
                assert loss_name in self.distill_losses
Add assert_str
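A sketch of the suggested assert message (the "assert_str"), so a misconfigured connector fails with an explanation rather than a bare AssertionError; the stand-in dict and message wording are only examples.

```python
# Example of attaching an explanatory string to the assert, as suggested.
distill_losses = {'loss_s4': object()}  # stand-in for the real ModuleDict

loss_name = 'loss_s4'
assert loss_name in distill_losses, (
    f'Connector "{loss_name}" has no matching loss in distill_losses; '
    f'available losses: {sorted(distill_losses)}')
```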
        if connectors:
            for loss_name, connector_cfg in connectors.items():
                assert loss_name in self.distill_losses
                assert loss_name not in distill_connecotrs
No need?
…replace nn.bn to build_norm_layer.
                stride=1,
                padding=0,
                bias=False)
        _, self.bn = build_norm_layer(dict(type='BN'), out_channel)
Use norm_cfg instead of the hard-coded dict(type='BN').
        super().__init__(init_cfg)
        self.conv = build_conv_layer(
            conv_cfg, in_channel, out_channel, kernel_size=1)
        _, self.bn = build_norm_layer(dict(type='BN'), out_channel)
Use norm_cfg instead of the hard-coded dict(type='BN').
…ctor to realize automatic invocation.
I have three comments about this PR:
For comment 2, I think you can refer to the mapping method between recorder, loss, and feature.
…or.models.connector to mmrazor.models.architectures. 3. Merge stu_connectors and tea_connectors into connectors, and call connectors by their connector_name.
LGTM if the latest comments are fixed.
            from_student=True,
            recorder='bb_s4',
            record_idx=1,
            connector_name='loss_s4_sfeat'),
connector_name -> connector
@@ -148,25 +182,30 @@ def get_record(self,
                   recorder: str,
                   from_student: bool,
                   record_idx: int = 0,
                   data_idx: Optional[int] = None,
                   connector_name: Optional[str] = None) -> List:
connector_name -> connector
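One way the renamed connector argument could flow through get_record, sketched with plain containers; the real method reads from the distiller's recorders, so everything here (the records dict, the lambda connector) is simplified and hypothetical.

```python
# Simplified, hypothetical sketch of get_record after the suggested rename:
# the optional `connector` post-processes the recorded data before return.
def get_record(records, recorder, record_idx=0, data_idx=None, connector=None):
    data = records[recorder][record_idx]
    if data_idx is not None:
        data = data[data_idx]
    if connector is not None:
        data = connector(data)  # e.g. a built connector module
    return data


records = {'bb_s4': [[1, 2], [3, 4]]}
doubled = get_record(records, 'bb_s4', record_idx=1,
                     connector=lambda xs: [2 * x for x in xs])
print(doubled)  # prints [6, 8]
```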
…at each connector must be in connectors.
Hi @wilxy! First of all, we want to express our gratitude for your significant PR in the MMRazor project. Your contribution is highly appreciated, and we are grateful for your efforts in helping improve this open-source project during your personal time. We believe that many developers will benefit from your PR. We would also like to invite you to join our Special Interest Group (SIG) private channel on Discord, where you can share your experiences and ideas and build connections with like-minded peers. To join the SIG channel, simply message the moderator, OpenMMLab, on Discord, or briefly share your open-source contributions in the #introductions channel, and we will assist you. We look forward to seeing you there! Join us: https://discord.gg/UjgXkPWNqA If you have a WeChat account, you are welcome to join our community on WeChat by adding our assistant: openmmlabwx. Please add "mmsig + GitHub ID" as a remark when adding friends. :)
New Features

- Add connector components. A connector is mainly used for distillation; it usually converts the channel number of the input feature to align the features of student and teacher. In this PR, BaseConnector and three general connectors (SingleConvConnector, ConvBNConnector and ConvBNReLUConnector) are added. ConfigurableDistiller gains two connector-related parameters (student_connectors and teacher_connectors) and a build_connectors function to build connectors.
- Add the FitNet distillation algorithm and its README (where the model and log paths are empty, awaiting the training results of the unified benchmark before the release).
- Add L2Loss. It can be used as in the following example.

from mmrazor.models import L2Loss

loss_cfg = dict(loss_weight=10, normalize=False)
l2loss = L2Loss(**loss_cfg)

Improvements

- Add a student_trainable parameter in SingleTeacherDistill to control whether the student model is trainable.
- Add a calculate_student_loss attribute in SingleTeacherDistill to control whether the original task loss is calculated and used to update the student model.
- Add test_connectors and test_losses in tests/test_models.
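For intuition, the arithmetic behind an L2 feature-mimicking loss can be sketched on plain numbers. This is only an illustration of the formula; mmrazor's actual L2Loss operates on torch tensors and has more options.

```python
# Toy version of an L2 distillation loss: weighted sum of squared differences
# between student and teacher features. Illustrative only, not mmrazor's code.
def l2_loss(s_feature, t_feature, loss_weight=1.0, normalize=False):
    total = sum((s - t) ** 2 for s, t in zip(s_feature, t_feature))
    if normalize:
        total /= len(s_feature)
    return loss_weight * total


print(l2_loss([1.0, 2.0], [1.0, 4.0], loss_weight=10))  # prints 40.0
```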