Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PyTorch Object Detector Training Implementation and Prediction Update #2067

Merged

Conversation

f4str
Copy link
Collaborator

@f4str f4str commented Mar 13, 2023

Description

Implementation of the fit method for the PyTorchObjectDetector, PyTorchFasterRCNN, and PyTorchYolo object detectors. This allows these models to be trained using ART. The notebook in notebooks/poisoning_attack_bad_det_rma.ipynb has been updated to demonstrate training these models. This is a partial implementation of #2058 as the fit method needs to also be implemented for the TensorFlow object detectors.

The predict method in the PyTorchObjectDetector and PyTorchYolo classes has been rewritten to implement the following new features:

  • Cleanup and refactoring so the code is easier to read and understand.
  • The torch.no_grad() scope is used like the PyTorchClassifier to prevent extraneous gradients accumulating in the model which causes slow down.
  • Inputs are now batched so the entire input is not processed at once.
  • The channels_first parameter is now used to determine whether the input needs to be transformed. The default behavior remains the same.

Additionally, the unit tests for the test_pytorch_object_detector.py and test_pytorch_faster_rcnn.py were rewritten in pytest with a new test added for model training.

Type of change

Please check all relevant options.

  • Improvement (non-breaking)
  • Bug fix (non-breaking)
  • New feature (non-breaking)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

Testing

Please describe the tests that you ran to verify your changes. Consider listing any relevant details of your test configuration.

  • Tests for PyTorchObjectDetector (rewritten in pytest)
  • Tests for PyTorchFasterRCNN (rewritten in pytest)
  • Tests for PyTorchYolo

Test Configuration:

  • OS
  • Python version
  • ART version or commit number
  • TensorFlow / Keras / PyTorch / MXNet version

Checklist

  • My code follows the style guidelines of this project
  • I have performed a self-review of my own code
  • I have commented my code
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

f4str added 10 commits March 12, 2023 15:21
Signed-off-by: Farhan Ahmed <Farhan.Ahmed@ibm.com>
Signed-off-by: Farhan Ahmed <Farhan.Ahmed@ibm.com>
Signed-off-by: Farhan Ahmed <Farhan.Ahmed@ibm.com>
Signed-off-by: Farhan Ahmed <Farhan.Ahmed@ibm.com>
Signed-off-by: Farhan Ahmed <Farhan.Ahmed@ibm.com>
Signed-off-by: Farhan Ahmed <Farhan.Ahmed@ibm.com>
Signed-off-by: Farhan Ahmed <Farhan.Ahmed@ibm.com>
Signed-off-by: Farhan Ahmed <Farhan.Ahmed@ibm.com>
Signed-off-by: Farhan Ahmed <Farhan.Ahmed@ibm.com>
Signed-off-by: Farhan Ahmed <Farhan.Ahmed@ibm.com>
@codecov-commenter
Copy link

codecov-commenter commented Mar 13, 2023

Codecov Report

Merging #2067 (f774359) into dev_1.14.0 (f48d73e) will increase coverage by 3.47%.
The diff coverage is 77.33%.

📣 This organization is not using Codecov’s GitHub App Integration. We recommend you install it so Codecov can continue to function properly for your repositories. Learn more

Impacted file tree graph

@@              Coverage Diff               @@
##           dev_1.14.0    #2067      +/-   ##
==============================================
+ Coverage       77.23%   80.70%   +3.47%     
==============================================
  Files             294      294              
  Lines           26212    26322     +110     
  Branches         4797     4827      +30     
==============================================
+ Hits            20244    21244    +1000     
+ Misses           4822     3914     -908     
- Partials         1146     1164      +18     
Impacted Files Coverage Δ
...estimators/object_detection/pytorch_faster_rcnn.py 100.00% <ø> (+22.22%) ⬆️
art/estimators/object_detection/pytorch_yolo.py 75.43% <75.00%> (+60.98%) ⬆️
...mators/object_detection/pytorch_object_detector.py 76.79% <79.72%> (+64.15%) ⬆️

... and 24 files with indirect coverage changes

@f4str f4str changed the title PyTorch Object Detector Training and Prediction Batching PyTorch Object Detector Training Implementation and Prediction Update Mar 13, 2023
Signed-off-by: Farhan Ahmed <Farhan.Ahmed@ibm.com>
@f4str f4str marked this pull request as ready for review March 13, 2023 02:15
@beat-buesser beat-buesser added the improvement Improve implementation label Mar 13, 2023
@beat-buesser beat-buesser added this to the ART 1.14.0 milestone Mar 13, 2023
@beat-buesser beat-buesser self-requested a review March 13, 2023 10:23
@beat-buesser beat-buesser self-assigned this Mar 13, 2023
Copy link
Collaborator

@beat-buesser beat-buesser left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @f4str Thank you very much for your pull request. I have added a few review comments, what do you think?

art/estimators/object_detection/pytorch_faster_rcnn.py Outdated Show resolved Hide resolved
art/estimators/object_detection/pytorch_faster_rcnn.py Outdated Show resolved Hide resolved
art/estimators/object_detection/pytorch_faster_rcnn.py Outdated Show resolved Hide resolved
art/estimators/object_detection/pytorch_object_detector.py Outdated Show resolved Hide resolved
art/estimators/object_detection/pytorch_yolo.py Outdated Show resolved Hide resolved
art/estimators/object_detection/pytorch_yolo.py Outdated Show resolved Hide resolved
Signed-off-by: Farhan Ahmed <Farhan.Ahmed@ibm.com>
Signed-off-by: Farhan Ahmed <Farhan.Ahmed@ibm.com>
@f4str f4str requested a review from beat-buesser March 13, 2023 20:05
Comment on lines 424 to 427
# Apply postprocessing
predictions = self._apply_postprocessing(preds=results_list, fit=False)

return predictions # type: ignore
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Apply postprocessing
predictions = self._apply_postprocessing(preds=results_list, fit=False)
return predictions # type: ignore
return predictions

Comment on lines 467 to 470
# Apply postprocessing
predictions = self._apply_postprocessing(preds=results_list, fit=False)

return predictions # type: ignore
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Apply postprocessing
predictions = self._apply_postprocessing(preds=results_list, fit=False)
return predictions # type: ignore
return predictions

Comment on lines 467 to 470
# Apply postprocessing
predictions = self._apply_postprocessing(preds=results_list, fit=False)

return predictions # type: ignore
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Apply postprocessing
predictions = self._apply_postprocessing(preds=results_list, fit=False)
return predictions # type: ignore
return predictions

Signed-off-by: Farhan Ahmed <Farhan.Ahmed@ibm.com>
@f4str f4str requested a review from beat-buesser March 13, 2023 21:43
Copy link
Collaborator

@beat-buesser beat-buesser left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@f4str Looks good to me! Thank you very much.

@beat-buesser beat-buesser merged commit 587783b into Trusted-AI:dev_1.14.0 Mar 15, 2023
@f4str f4str deleted the pytorch-object-detector-training branch March 15, 2023 02:45
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
improvement Improve implementation
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants