
[Feature] Add connector components and FitNet #207

Merged
merged 6 commits on Jul 28, 2022

Conversation


@wilxy wilxy commented Jul 26, 2022

New Features

  1. Add connector components. A connector is mainly used for distillation; it usually converts the channel number of an input feature to align the features of the student and the teacher. This PR adds BaseConnector and three general connectors (SingleConvConnector, ConvBNConnector and ConvBNReLUConnector). See the configuration sketch after this list.

  2. Add two connector-related parameters (student_connectors and teacher_connectors) and a build_connectors function to ConfigurableDistiller for building connectors.

  3. Add the FitNet distillation algorithm and its README (the model and log paths are left empty until the training results from the unified benchmark are available for the release).

  4. Add L2Loss. It can be used as in the following example:
    from mmrazor.models import L2Loss
    loss_cfg = dict(loss_weight=10, normalize=False)
    l2loss = L2Loss(**loss_cfg)
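
A rough sketch of how the new connector fields might be wired into a ConfigurableDistiller config is shown below. The loss name loss_s4, the connector choices per side, and the channel numbers are made up for illustration (they are not taken from the actual FitNet config), and the distiller's other required fields are omitted; connectors are keyed by the name of the distill loss they feed.

    # Hypothetical distiller config (illustrative values only).
    distiller = dict(
        type='ConfigurableDistiller',
        distill_losses=dict(
            loss_s4=dict(type='L2Loss', loss_weight=10, normalize=False)),
        # Each connector entry is keyed by the name of the loss it serves.
        student_connectors=dict(
            loss_s4=dict(
                type='ConvBNReLUConnector', in_channel=128, out_channel=256)),
        teacher_connectors=dict(
            loss_s4=dict(
                type='SingleConvConnector', in_channel=256, out_channel=256)))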

Improvement

  1. Add a student_trainable parameter to SingleTeacherDistill to control whether the student model is trainable.

  2. Add a calculate_student_loss attribute to SingleTeacherDistill to control whether the original task loss is calculated and used to update the student model. See the sketch after this list.

  3. Add test_connectors and test_losses in tests/test_models.
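
A minimal sketch of how these two switches might look in a SingleTeacherDistill config; the architecture, teacher and distiller entries are omitted, the values shown are illustrative, and no defaults are implied.

    # Hypothetical algorithm config showing only the two new switches.
    algorithm = dict(
        type='SingleTeacherDistill',
        # Whether the student model's weights are updated during training.
        student_trainable=True,
        # Whether the original task loss is computed and used to update the
        # student in addition to the distillation losses.
        calculate_student_loss=True)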

@wilxy wilxy requested review from sunnyxiaohu and pppppM July 26, 2022 13:32

codecov bot commented Jul 26, 2022

Codecov Report

Merging #207 (d75d120) into dev-1.x (6987511) will decrease coverage by 0.01%.
The diff coverage is 0.00%.

@@            Coverage Diff             @@
##           dev-1.x    #207      +/-   ##
==========================================
- Coverage     0.68%   0.67%   -0.02%     
==========================================
  Files          115     119       +4     
  Lines         4216    4321     +105     
  Branches       659     675      +16     
==========================================
  Hits            29      29              
- Misses        4182    4287     +105     
  Partials         5       5              
Flag Coverage Δ
unittests 0.67% <0.00%> (-0.02%) ⬇️


Impacted Files Coverage Δ
...hms/distill/configurable/single_teacher_distill.py 0.00% <0.00%> (ø)
mmrazor/models/architectures/__init__.py 0.00% <0.00%> (ø)
...mrazor/models/architectures/connectors/__init__.py 0.00% <0.00%> (ø)
.../models/architectures/connectors/base_connector.py 0.00% <0.00%> (ø)
...dels/architectures/connectors/general_connector.py 0.00% <0.00%> (ø)
mmrazor/models/distillers/base_distiller.py 0.00% <0.00%> (ø)
...mrazor/models/distillers/configurable_distiller.py 0.00% <0.00%> (ø)
mmrazor/models/losses/__init__.py 0.00% <0.00%> (ø)
mmrazor/models/losses/l2_loss.py 0.00% <0.00%> (ø)

Δ = absolute <relative> (impact), ø = not affected, ? = missing data

class BaseConnector(nn.Module, metaclass=ABCMeta):
    """Base class of connectors.

    Connector is mainly used for distill, it usually converts the channel

distill -> distillation

import torch.nn as nn


class BaseConnector(nn.Module, metaclass=ABCMeta):

Would it be better to inherit from BaseModule? Then self.init_parameters() could be replaced by init_weights() together with init_cfg (see the sketch below).
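
A sketch of what this suggestion could look like, assuming mmengine's BaseModule (which provides init_cfg and init_weights()); the forward_train method name below is only illustrative, not the PR's actual interface.

    from abc import ABCMeta, abstractmethod

    from mmengine.model import BaseModule


    class BaseConnector(BaseModule, metaclass=ABCMeta):
        """Base class of connectors (sketch of the suggested change)."""

        def __init__(self, init_cfg=None):
            # BaseModule stores init_cfg so that init_weights() can handle
            # parameter initialization, replacing a custom init_parameters().
            super().__init__(init_cfg=init_cfg)

        @abstractmethod
        def forward_train(self, feature):
            """Convert the input feature, e.g. align its channel number."""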



@MODELS.register_module()
class BNConnector(BaseConnector):

-> ConvBNConnector

out_channel: int,
) -> None:
super().__init__()
self.conv = nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1)

Use build_conv_layer to build a conv layer
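
A hedged sketch of this suggestion using mmcv.cnn.build_conv_layer; the class body is illustrative, and passing conv_cfg=None falls back to a plain nn.Conv2d.

    import torch.nn as nn
    from mmcv.cnn import build_conv_layer


    class SingleConvConnector(nn.Module):
        """Sketch only: build the conv through mmcv instead of nn.Conv2d."""

        def __init__(self, in_channel, out_channel, conv_cfg=None):
            super().__init__()
            # conv_cfg=None builds a regular nn.Conv2d; other conv types can
            # be selected via a config such as dict(type='Conv2d').
            self.conv = build_conv_layer(
                conv_cfg, in_channel, out_channel, kernel_size=1, stride=1)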

out_channel: int,
) -> None:
super().__init__()
self.conv = nn.Conv2d(

Use build_conv_layer to build a conv layer
build_norm_layer to build a norm layer
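
Similarly, a sketch combining build_conv_layer and build_norm_layer; exposing norm_cfg as an argument (as also suggested further below) keeps the norm layer configurable. The class body is illustrative only.

    import torch.nn as nn
    from mmcv.cnn import build_conv_layer, build_norm_layer


    class ConvBNConnector(nn.Module):
        """Sketch only: build the conv and the norm layer from configs."""

        def __init__(self, in_channel, out_channel, conv_cfg=None,
                     norm_cfg=dict(type='BN')):
            super().__init__()
            self.conv = build_conv_layer(
                conv_cfg, in_channel, out_channel, kernel_size=1, bias=False)
            # build_norm_layer returns (name, layer); only the layer is kept.
            _, self.bn = build_norm_layer(norm_cfg, out_channel)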



@MODELS.register_module()
class ReLUConnector(BaseConnector):

-> ConvBNReLUConnector?

out_channel: int,
) -> None:
super().__init__()
self.conv = nn.Conv2d(in_channel, out_channel, kernel_size=1)

The same as above: use build_conv_layer here as well.

@@ -38,6 +38,14 @@ class ConfigurableDistiller(BaseDistiller):
distill_deliveries (dict, optional): Config for multiple deliveries. A
distill algorithm may have more than one delivery. Defaults to
None.
student_connectors (dict, optional): Config for multiple connectors. A

Explain the key mapping relations between connectors, distill_losses, and loss_forward_mappings in the Note below.

distill_connecotrs = nn.ModuleDict()
if connectors:
for loss_name, connector_cfg in connectors.items():
assert loss_name in self.distill_losses

Add an assert message (assert_str).
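
For example, the check in the loop above could carry a message along these lines (the wording is illustrative):

    assert loss_name in self.distill_losses, (
        f'Connector "{loss_name}" has no corresponding distill loss; '
        f'available loss names: {list(self.distill_losses.keys())}.')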

if connectors:
for loss_name, connector_cfg in connectors.items():
assert loss_name in self.distill_losses
assert loss_name not in distill_connecotrs

No need?

@wilxy wilxy changed the title from "Fix spelling mistakes" to "[Feature] Add connector components and FitNet" on Jul 27, 2022
@wilxy wilxy requested a review from sunnyxiaohu July 27, 2022 07:30
stride=1,
padding=0,
bias=False)
_, self.bn = build_norm_layer(dict(type='BN'), out_channel)

Use norm_cfg here instead of hard-coding dict(type='BN').

super().__init__(init_cfg)
self.conv = build_conv_layer(
conv_cfg, in_channel, out_channel, kernel_size=1)
_, self.bn = build_norm_layer(dict(type='BN'), out_channel)

Use norm_cfg here as well.


pppppM commented Jul 28, 2022

I have three comments about this PR:

  1. calculate_student_loss should be an attribute of the distillation algorithm, not the distiller
  2. The current mapping between connectors and features cannot handle a loss that has multiple inputs where each input needs a different connector.
  3. mmrazor.models.connector can be moved to mmrazor.models.architectures

For comment 2, I think you can refer to the mapping method between recorder, loss, and feature.
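
For reference, a hedged sketch of what such a mapping could look like, following the existing recorder-to-loss mapping style; the loss, recorder and connector names below are made up, and the exact key names may differ from the final implementation.

    # Hypothetical loss_forward_mappings entry: one loss with two inputs,
    # each routed through its own connector, referenced next to the recorder
    # that provides the feature.
    loss_forward_mappings = dict(
        loss_s4=dict(
            s_feature=dict(
                from_student=True,
                recorder='bb_s4',
                record_idx=1,
                connector_name='loss_s4_sfeat'),
            t_feature=dict(
                from_student=False,
                recorder='bb_s4',
                connector_name='loss_s4_tfeat')))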

…or.models.connector to mmrazor.models.architectures. 3. Merge stu_connectors and tea_connectors into connectors, and call connectors by their connector_name.

@pppppM pppppM left a comment


LGTM if the latest comments are fixed.

from_student=True,
recorder='bb_s4',
record_idx=1,
connector_name='loss_s4_sfeat'),

connector_name -> connector

@@ -148,25 +182,30 @@ def get_record(self,
                   recorder: str,
                   from_student: bool,
                   record_idx: int = 0,
-                  data_idx: Optional[int] = None) -> List:
+                  data_idx: Optional[int] = None,
+                  connector_name: Optional[str] = None) -> List:

connector_name -> connector

@sunnyxiaohu sunnyxiaohu merged commit 63d46b5 into open-mmlab:dev-1.x Jul 28, 2022
humu789 pushed a commit to humu789/mmrazor that referenced this pull request Feb 13, 2023

humu789 pushed a commit to humu789/mmrazor that referenced this pull request Feb 13, 2023
@OpenMMLab-Assistant001

Hi @wilxy! First of all, we want to express our gratitude for your significant PR to the MMRazor project. Your contribution is highly appreciated, and we are grateful for your efforts in helping improve this open-source project in your personal time. We believe that many developers will benefit from your PR.

We would also like to invite you to join our Special Interest Group (SIG) private channel on Discord, where you can share your experiences and ideas and build connections with like-minded peers. To join the SIG channel, simply message the moderator, OpenMMLab, on Discord, or briefly share your open-source contributions in the #introductions channel and we will assist you. We look forward to seeing you there! Join us: https://discord.gg/UjgXkPWNqA

If you have a WeChat account, you are welcome to join our community on WeChat. You can add our assistant: openmmlabwx. Please add "mmsig + GitHub ID" as a remark when adding friends. :)
Thank you again for your contribution! ❤
