
CenterNet MobileNetV2 FPN 512x512 not trainable (or a bug of evaluation) #10065

Closed
lisosia opened this issue Jun 12, 2021 · 8 comments

lisosia commented Jun 12, 2021

Prerequisites

Please answer the following questions for yourself before submitting an issue.

  • [v] I am using the latest TensorFlow Model Garden release and TensorFlow 2.
  • [v] I am reporting the issue to the correct repository. (Model Garden official or research directory)
  • [v] I checked to make sure that this issue has not already been filed.

1. The entire URL of the file you are using

https://github.com/tensorflow/models/tree/master/research/object_detection

2. Describe the bug

An issue with training "CenterNet MobileNetV2 FPN 512x512", while the other models are trainable.

I conducted an overfit-training test to verify that the models can be trained.
I tested 3 models, and only the "CenterNet MobileNetV2" training fails.

3. Steps to reproduce

Check out commit 0c9253b.

The configs used and create_voc_subset_tfrecord.py are attached:
sanitiy_check.zip

  1. Prepare a subset of the VOC dataset (128 images, 20 classes). The data is downloaded automatically by the tensorflow_datasets package.
python3 create_voc_subset_tfrecord.py --output_path=sanitiy_check/test.tfrecord
  2. Train the models without augmentation: 8000 steps, lr ≈ 1e-3, batch size 6 or 8.
    Please refer to the attached configs.
C=<CONFIG_NAME>.config
python3 model_main_tf2.py --pipeline_config_path=sanitiy_check/$C --model_dir=sanitiy_check/model_dir/$C
  3. Evaluate on the same data used for training.

Modify model_lib_v2.py so the evaluation runs once on the latest checkpoint:

-  for latest_checkpoint in tf.train.checkpoints_iterator(
-      checkpoint_dir, timeout=timeout, min_interval_secs=wait_interval):
+  for latest_checkpoint in [tf.train.latest_checkpoint(checkpoint_dir)]:

Then run the evaluation:

python3 model_main_tf2.py --pipeline_config_path=sanitiy_check/$C --model_dir=sanitiy_check/model_dir/$C --checkpoint_dir=sanitiy_check/model_dir/$C
  4. Check the mAP and loss: verify that mAP@0.75 is around 1.0 and that the val loss is almost the same as the train loss at 8000 steps.

Only the result of CenterNet MobileNetV2 is clearly incorrect:
the train loss decreases during training, but the val loss stays high and mAP@0.75 is 0.388.

### centernet_resnet101_v1_fpn_512x512_coco17_tpu-8.config
- train
I0612 09:34:58.359249 140411785889600 model_lib_v2.py:683] Step 8000 per-step time 0.379s loss=0.455
- eval
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.868
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.983
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.980
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.800
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.850
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.873
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.627
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.888
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.897
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.800
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.876
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.897
I0612 09:38:10.690112 140116521383744 model_lib_v2.py:992] 	+ Loss/total_loss: 0.569873

### centernet_mobilenetv2_fpn_512x512_coco17_tpu-8.config
- train
I0612 07:51:29.272449 139976995292992 model_lib_v2.py:683] Step 8000 per-step time 0.277s loss=0.368
- eval
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.373
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.644
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.388
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.000
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.325
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.380
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.396
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.533
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.543
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.000
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.471
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.548
I0612 07:52:17.943458 140633233336128 model_lib_v2.py:992] 	+ Loss/total_loss: 4.591526

### ssd_mobilenet_v2_fpnlite_320x320_coco17_tpu-8.config
- train
I0612 08:13:31.177959 140652124112704 model_lib_v2.py:683] Step 8000 per-step time 0.121s loss=0.019
- eval
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.981
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 1.000
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 1.000
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 1.000
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.971
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.984
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.684
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.977
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.987
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 1.000
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.981
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.989
I0612 08:20:40.350211 140584744761152 model_lib_v2.py:992] 	+ Loss/total_loss: 0.019054

4. Expected behavior

  • mAP@0.75 is around 1.0
  • the val loss is almost the same as the train loss at 8000 steps

5. Additional context

Although not directly related to the subject:

  • A centernet_mobilenetv2 config should be added to configs/tf2/.
  • I think the expected behavior of model_main_tf2.py when --checkpoint_dir is specified would be to evaluate once using the latest checkpoint.
  • Many people seem to have trouble training models; a sanity check like the one above may be useful for other models as well.

6. System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 18.04 (docker), Ubuntu 20.04(host)
  • Mobile device name if the issue happens on a mobile device: N/A
  • TensorFlow installed from (source or binary): binary
  • TensorFlow version (use command below): 2.4.1
  • Python version: 3.6.9
  • Bazel version (if compiling from source):
  • GCC/Compiler version (if compiling from source):
  • CUDA/cuDNN version: 11.0
  • GPU model and memory: RTX 2070 Super, 8 GB
@lisosia lisosia added models:research models that come under research directory type:bug Bug in the code labels Jun 12, 2021

lisosia commented Jun 15, 2021

  • Using type: "mobilenet_v2_fpn" instead of "mobilenet_v2_fpn_sep_conv" makes no difference.

  • I called eager_eval_loop() during the training loop and found that the eval result (@3000 steps, @4000 steps) is also bad (total loss ≈ 10.0), so the problem is not in checkpoint saving/restoring.

  • Both the train_input and eval_input images seem to be scaled to [-1, 1] (a quick check along the lines of the sketch below).
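
For reference, a minimal sketch of the kind of scaling check I mean, assuming the input pipeline yields (features, labels) pairs with the image batch under the 'image' key as model_lib_v2.py builds them; the names are illustrative, not code from this repo:

```python
import tensorflow as tf

# Sketch only: assumes `eval_input` is a tf.data.Dataset of (features, labels)
# where features['image'] holds the preprocessed image batch.
features, labels = next(iter(eval_input))
images = features['image']
print('min/max:', float(tf.reduce_min(images)), float(tf.reduce_max(images)))
# Expect roughly -1.0 / 1.0 if the [-1, 1] scaling is applied on both paths.
```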

To be honest, I am finding it difficult to debug the code in this repository.
I would be very happy if the TF Object Detection Team could work on this problem.

I believe CenterNet is one of the most important architectures for mobile use (e.g. Next-Generation Pose Detection with MoveNet and TensorFlow.js).


lisosia commented Jun 16, 2021

I decided to use a dataset of size 1 to simplify things (batch_size set to 1 for both train and eval).

The change below eliminates the mismatch between train loss and eval loss, and I got mAP = 1.0 at 1000 steps.
meta_architectures/center_net_meta_arch.py:

class CenterNetMetaArch
  def predict(...
-    features_list = self._feature_extractor(preprocessed_inputs)
+    features_list = self._feature_extractor(preprocessed_inputs, training=False)

So the problem stems from batch normalization (or possibly dropout).

I will look into how to handle this properly later, if possible.

  • FreezableBatchNorm (core/freezable_batch_norm.py) is used by mobilenet_v2.
    Its def call(self, inputs, training=None): defaults to training=None.
class FreezableBatchNorm(tf.keras.layers.BatchNormalization):
  """Batch normalization layer (Ioffe and Szegedy, 2014).

  This is a `freezable` batch norm layer that supports setting the `training`
  parameter in the __init__ method rather than having to set it either via
  the Keras learning phase or via the `call` method parameter. This layer will
  forward all other parameters to the default Keras `BatchNormalization`
  layer

  This is class is necessary because Object Detection model training sometimes
  requires batch normalization layers to be `frozen` and used as if it was
  evaluation time, despite still training (and potentially using dropout layers)

  Like the default Keras BatchNormalization layer, this will normalize the
  activations of the previous layer at each batch,
  i.e. applies a transformation that maintains the mean activation
  close to 0 and the activation standard deviation close to 1.

  Args:
    training: If False, the layer will normalize using the moving average and
      std. dev, without updating the learned avg and std. dev.
      If None or True, the layer will follow the keras BatchNormalization layer
      strategy of checking the Keras learning phase at `call` time to decide
      what to do.
    **kwargs: The keyword arguments to forward to the keras BatchNormalization
        layer constructor.
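
For intuition, here is a standalone sketch (not code from this repo) of how the training flag changes what a Keras BatchNormalization layer normalizes with; this is the behavior FreezableBatchNorm wraps:

```python
import numpy as np
import tensorflow as tf

bn = tf.keras.layers.BatchNormalization(momentum=0.9997)
x = tf.constant(np.random.normal(5.0, 2.0, size=(8, 16)).astype(np.float32))

y_train = bn(x, training=True)   # uses the batch mean/variance and only
                                 # nudges the moving averages by 1 - momentum
y_eval = bn(x, training=False)   # uses the moving averages, which are still
                                 # close to their initial values (0 and 1)

print(float(tf.math.reduce_std(y_train)))  # ~1.0: batch statistics applied
print(float(tf.math.reduce_std(y_eval)))   # ~2.0: stale moving statistics applied
```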


lisosia commented Jun 17, 2021

I finally found that the eval loss is large simply because the learned moving averages of μ and σ are not yet close to the μ and σ of the batch (which is the only batch in this experiment) at the evaluation after 1000 steps.

The difference between SSD and CenterNet is probably just the batch norm decay value:
SSD uses decay=0.997 while CenterNet uses the mobilenet_v2 default decay=0.9997 (PyTorch, which I am more familiar with, defaults to 0.9).
When I set decay=0.9997 for SSD, the same issue occurs, and decay=0.90 does not cause the issue.


default_batchnorm_momentum=0.9997,

The problem would go away after sufficient steps.
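
As a back-of-the-envelope check (assuming the standard EMA update moving_stat ← momentum · moving_stat + (1 − momentum) · batch_stat), the fraction of the stale initial statistics still left after N updates is momentum^N:

```python
for momentum in (0.9, 0.997, 0.9997):
    for steps in (1000, 8000):
        remaining = momentum ** steps
        print(f"momentum={momentum}: after {steps} steps, "
              f"{remaining:.3f} of the old statistics remain")
# momentum=0.9997 still keeps ~74% of the stale statistics after 1000 steps
# and ~9% after 8000, while momentum=0.9 forgets them within a few dozen steps.
```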

During evaluation, batch norm probably behaves in inference mode (as expected) without any code modifications in this repo.


I hope someone in the know can confirm the above conclusion.

I am NOT 100% sure about this point:

"During evaluation, batch norm probably behaves in inference mode (as expected) without any code modifications in this repo."

@jaeyounkim jaeyounkim added models:research:odapi ODAPI and removed models:research models that come under research directory labels Jun 25, 2021
@Swazir9449

@lisosia I don't have the solution to this problem. But I was wondering if you know how to encode keypoints for this model, since they are supported for it. I cannot find any documentation on this.


lisosia commented Jul 6, 2021

@Swazir9449

First of all, object detection and keypoint detection are two different things and need to be distinguished.

I have not used this repository for keypoint detection, but as far as I know, there is no documentation for keypoint detection.
Therefore, you need to check the config file or the script for dataset creation to understand what you need to do.

For example

  • configs/tf2/centernet_resnet50_v2_512x512_kpts_coco17_tpu-8.config
  • dataset_tools/create_coco_tf_record.py

If you want to know about CenterNet itself, check the original paper.


lisosia commented Jul 6, 2021

My conclusion is probably correct, and since there seems to be no response, I will close the issue.

I think the proper way is to add an option to protos/center_net.proto to specify the momentum, but in my case it was enough to override the value in models/keras_models/mobilenet_v2.py.
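
For reference, the local workaround amounts to something like the following (a sketch; the surrounding signature in models/keras_models/mobilenet_v2.py may differ, but default_batchnorm_momentum is the keyword quoted above):

```
# models/keras_models/mobilenet_v2.py (sketch)
-    default_batchnorm_momentum=0.9997,
+    default_batchnorm_momentum=0.997,  # match SSD so the moving averages catch up sooner
```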

@lisosia lisosia closed this as completed Jul 6, 2021
@google-ml-butler

Are you satisfied with the resolution of your issue?

@Annieliaquat

Is this issue solved now? Is CenterNet MobileNetV2 FPN 512x512 trainable now?
