Incorrect calculations in GridAnchorGenerator, gridAnchorLayer for rectangular inputs #1563

Closed
marcbelmont opened this issue Oct 19, 2021 · 3 comments
Labels
triaged Issue has been triaged by maintainers


marcbelmont commented Oct 19, 2021

Description

If the input layer of an object detector is not square, libnvinfer_plugin does not produce the correct bounding boxes. An incomplete fix was committed to master in #679 by @rajeevsrao, but parts of it are no longer in master. The issue was also discussed in #807.

Environment

  • NVIDIA Jetson TX2
    • Jetpack 4.6 [L4T 32.6.1]
    • NV Power Mode: MAXP_CORE_ARM - Type: 3
    • jetson_stats.service: active
  • Libraries:
    • CUDA: 10.2.300
    • cuDNN: 8.2.1.32
    • TensorRT: 8.0.1.6
    • Visionworks: 1.6.0.501
    • OpenCV: 4.1.1 compiled CUDA: NO
    • VPI: ii libnvvpi1 1.1.12 arm64 NVIDIA Vision Programming Interface library
    • Vulkan: 1.2.70

Relevant Files and Fix

The following changes fix the issue.

modified   plugin/common/kernels/gridAnchorLayer.cu                                                                                                                                        
@@ -34,8 +34,10 @@ __launch_bounds__(nthdsPerCTA) __global__ void gridAnchorKernel(const GridAnchor                                                                                        
      * the image Every coordinate will go back to the pixel coordinates in the input image if being multiplied by                                                                         
      * image_input_size Here we implicitly assumes the image input and feature map are square                                                                                             
      */                                                                                                                                                                                   
-    float anchorStride = (1.0 / param.H);                                                                                                                                                 
-    float anchorOffset = 0.5 * anchorStride;                                                                                                                                              
+    float anchorStrideH = (1.0 / param.H);                                                                                                                                                
+    float anchorOffsetH = 0.5 * anchorStrideH;                                                                                                                                            
+    float anchorStrideW = (1.0 / param.W);                                                                                                                                                
+    float anchorOffsetW = 0.5 * anchorStrideW;                                                                                                                                            
                                                                                                                                                                                           
     int tid = blockIdx.x * blockDim.x + threadIdx.x;                                                                                                                                      
     if (tid >= dim)                                                                                                                                                                       
@@ -47,8 +49,8 @@ __launch_bounds__(nthdsPerCTA) __global__ void gridAnchorKernel(const GridAnchor                                                                                         
     const int h = currIndex / param.W;                                                                                                                                                    
                                                                                                                                                                                           
     // Center coordinates                                                                                                                                                                 
-    float yC = h * anchorStride + anchorOffset;                                                                                                                                           
-    float xC = w * anchorStride + anchorOffset;                                                                                                                                           
+    float yC = h * anchorStrideH + anchorOffsetH;                                                                                                                                         
+    float xC = w * anchorStrideW + anchorOffsetW;                                                                                                                                         
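
To make the effect of the kernel change easier to see, here is a minimal host-side sketch in Python of the center computation before and after the patch. It is not the plugin code: the function names are hypothetical, it handles a single feature map with one anchor per cell, and it assumes the column index is `currIndex % param.W` (only the row index appears in the diff above).

```python
def anchor_centers(feat_h, feat_w):
    """Patched behaviour: separate strides for height and width.

    Returns normalized (yC, xC) centers in row-major cell order.
    """
    stride_h = 1.0 / feat_h
    stride_w = 1.0 / feat_w
    offset_h = 0.5 * stride_h
    offset_w = 0.5 * stride_w
    centers = []
    for idx in range(feat_h * feat_w):
        w = idx % feat_w   # assumed column index, not shown in the diff
        h = idx // feat_w  # row index, as in the kernel
        centers.append((h * stride_h + offset_h, w * stride_w + offset_w))
    return centers


def anchor_centers_square_assumption(feat_h, feat_w):
    """Pre-patch behaviour: one stride derived from the height only."""
    stride = 1.0 / feat_h
    offset = 0.5 * stride
    centers = []
    for idx in range(feat_h * feat_w):
        w = idx % feat_w
        h = idx // feat_w
        centers.append((h * stride + offset, w * stride + offset))
    return centers
```

For a hypothetical 2 x 4 feature map (H = 2, W = 4), the last cell of the first row gets xC = 0.875 with the patched computation and xC = 1.75 with the single-stride computation, so the unpatched center falls outside the normalized [0, 1] range.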

modified   plugin/gridAnchorPlugin/gridAnchorPlugin.cpp                                                                                                                                    
@@ -109,11 +109,13 @@ GridAnchorGenerator::GridAnchorGenerator(const GridAnchorParameters* paramIn, in                                                                                     
                                                                                                                                                                                           
         std::vector<float> tmpWidths;                                                                                                                                                     
         std::vector<float> tmpHeights;                                                                                                                                                    
+        float featMapAspectRatio = (float) (mParam[0].H) / (float) (mParam[0].W);                                                                                                         
+        // TODO: calculate the ratio with the input layer height and width instead.                                                                                                             
         // Calculate the width and height of the prior boxes                                                                                                                              
         for (int i = 0; i < mNumPriors[id]; i++)                                                                                                                                          
         {                                                                                                                                                                                 
             float sqrt_AR = sqrt(aspect_ratios[i]);                                                                                                                                       
-            tmpWidths.push_back(scales[i] * sqrt_AR);                                                                                                                                     
+            tmpWidths.push_back(scales[i] * sqrt_AR * featMapAspectRatio);                                                                                                                
             tmpHeights.push_back(scales[i] / sqrt_AR);                                                                                                                                    
         }                                                                                                                                                                                 
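
The width correction can be sketched the same way. The snippet below mirrors the loop in the patched constructor for one feature map; the function name and arguments are hypothetical. My reading of the change, which the issue does not spell out, is that widths and heights are normalized by different image dimensions, so a box that should be square in pixels needs its normalized width scaled by H / W. As the TODO in the diff notes, the feature-map ratio is only a stand-in for the input layer ratio.

```python
import math


def prior_box_sizes(scales, aspect_ratios, feat_h, feat_w):
    """Normalized prior box widths and heights for one feature map.

    Mirrors the patched loop: widths are scaled by H / W of the feature
    map, used here as a proxy for the input aspect ratio.
    """
    feat_map_aspect_ratio = float(feat_h) / float(feat_w)
    widths = []
    heights = []
    for scale, aspect_ratio in zip(scales, aspect_ratios):
        sqrt_ar = math.sqrt(aspect_ratio)
        widths.append(scale * sqrt_ar * feat_map_aspect_ratio)
        heights.append(scale / sqrt_ar)
    return widths, heights
```

With a hypothetical 2 x 4 feature map, scale 0.2 and aspect ratio 1.0, this gives a normalized width of 0.1 and height of 0.2. On a 100 x 200 (H x W) input that is 20 px by 20 px, i.e. a square box, whereas the unpatched width of 0.2 would give 40 px by 20 px.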

Steps To Reproduce

Example of how the plugin node is created with graphsurgeon:

import graphsurgeon as gs

# Non-square feature maps: featureMapShapes holds two values per layer (6 layers).
gs.create_plugin_node(
    name="MultipleGridAnchorGenerator",
    op="GridAnchorRect_TRT",
    minSize=0.2,
    maxSize=0.95,
    aspectRatios=[1.0, 2.0, 0.5, 3.0, 0.33],
    variance=[0.1, 0.1, 0.2, 0.2],
    featureMapShapes=[40, 23, 20, 12, 10, 6, 5, 3, 3, 2, 2, 1],
    numLayers=6,
)
@marcbelmont changed the title from "Incorrect calculations in GridAnchorGenerator, gridAnchorLayer for rectangular inputs." to "Incorrect calculations in GridAnchorGenerator, gridAnchorLayer for rectangular inputs" on Oct 19, 2021
@oxana-nvidia oxana-nvidia self-assigned this May 26, 2022
@oxana-nvidia
Collaborator

Hi @marcbelmont,
Thanks for reporting the issue!

I see two issues with the proposed solution:

  1. potential performance impact
  2. it changes plugin semantics

I've filed an internal ticket for the TensorRT developers to investigate whether the changes can be added to the code.
Internal bug number: 3659884
cc @rajeevsrao

@oxana-nvidia added the "triaged" label (Issue has been triaged by maintainers) on May 26, 2022
@oxana-nvidia
Collaborator

Hi @marcbelmont,
We've integrated the proposed change into TensorRT. It will be in TRT 8.5 GA (aka TRT 8.5.1).

@ttyio
Collaborator

ttyio commented Mar 1, 2023

Closing since there has been no activity for more than 3 weeks. Please reopen if you still have questions, thanks!

@ttyio ttyio closed this as completed Mar 1, 2023