Mobile SSD models are expected to have exactly 4 outputs, found 2 #5550

Open
libofei2004 opened this issue Jul 30, 2024 · 0 comments
Labels
os:linux-non-arm Issues on linux distributions which run on x86-64 architecture. DOES NOT include ARM devices. platform:android Issues with Android as Platform platform:python MediaPipe Python issues task:object detection Issues related to Object detection: Track and label objects in images and video. type:modelmaker Issues related to creation of custom on-device ML solutions type:support General questions

Have I written custom code (as opposed to using a stock example script provided in MediaPipe)

None

OS Platform and Distribution

Ubuntu 22 in WSL2, Android 12

Python Version

3.10

MediaPipe Model Maker version

2.0.4.1

Task name (e.g. Image classification, Gesture recognition etc.)

object detector

Describe the actual behavior

I used mediapipe_model_maker 2.0.4.1 to train a model and loaded it in an Android programme, but the programme fails to run and throws an exception.

Describe the expected behaviour

The model should load and run in the Android programme. Instead, the programme throws: java.lang.IllegalArgumentException: Error occurred when initializing ObjectDetector: Mobile SSD models are expected to have exactly 4 outputs, found 2

Standalone code/steps you may have used to try to get what you need

1. I trained a TFLite model with mediapipe_model_maker.
The code is:
import os

# Reduce TensorFlow C++ log output; must be set before TensorFlow is imported.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

from mediapipe_model_maker import object_detector

# Note: training and validation currently read from the same folder.
train_dataset_path = '/mnt/d/workspace/imgupload/img/selected1/bt3/train/shot'
validation_dataset_path = '/mnt/d/workspace/imgupload/img/selected1/bt3/train/shot'
cache_dir = '/mnt/d/workspace/imgupload/img/selected1/tmp'

# Load the Pascal VOC formatted datasets.
train_data = object_detector.Dataset.from_pascal_voc_folder(
    train_dataset_path,
    cache_dir=cache_dir)

validate_data = object_detector.Dataset.from_pascal_voc_folder(
    validation_dataset_path,
    cache_dir=cache_dir)

# Training configuration for the MobileNetV2-based detector.
hparams = object_detector.HParams(
    batch_size=8, learning_rate=0.3, epochs=50, export_dir='exported_model')
options = object_detector.ObjectDetectorOptions(
    supported_model=object_detector.SupportedModels.MOBILENET_V2,
    hparams=hparams)

# Train the model.
model = object_detector.ObjectDetector.create(
    train_data=train_data,
    validation_data=validate_data,
    options=options)

# Evaluate on the validation set, then export the TFLite model.
loss, coco_metrics = model.evaluate(validate_data, batch_size=4)
print(f"Validation loss: {loss}")
print(f"Validation coco metrics: {coco_metrics}")
model.export_model('dogs.tflite')


2. I use the model in an Android programme based on this example:
https://github.com/tensorflow/examples/tree/master/lite/examples/object_detection/android

The main code is:
package org.tensorflow.lite.examples.detection;

import android.graphics.Bitmap;
import android.graphics.Bitmap.Config;
import android.graphics.Canvas;
import android.graphics.Color;
import android.graphics.Matrix;
import android.graphics.Paint;
import android.graphics.Paint.Style;
import android.graphics.RectF;
import android.graphics.Typeface;
import android.media.ImageReader.OnImageAvailableListener;
import android.os.SystemClock;
import android.util.Size;
import android.util.TypedValue;
import android.widget.Toast;

import com.example.namespace.R;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.tensorflow.lite.examples.detection.customview.OverlayView;
import org.tensorflow.lite.examples.detection.customview.OverlayView.DrawCallback;
import org.tensorflow.lite.examples.detection.env.BorderedText;
import org.tensorflow.lite.examples.detection.env.ImageUtils;
import org.tensorflow.lite.examples.detection.env.Logger;
import org.tensorflow.lite.examples.detection.tflite.Detector;
import org.tensorflow.lite.examples.detection.tflite.TFLiteObjectDetectionAPIModel;
import org.tensorflow.lite.examples.detection.tracking.MultiBoxTracker;

/**
 * An activity that uses a TensorFlowMultiBoxDetector and ObjectTracker to detect and then track
 * objects.
 */
public class DetectorActivity extends CameraActivity implements OnImageAvailableListener {
  private static final Logger LOGGER = new Logger();

  // Configuration values for the prepackaged SSD model.
  private static final int TF_OD_API_INPUT_SIZE = 300;
  private static final boolean TF_OD_API_IS_QUANTIZED = true;
  //private static final String TF_OD_API_MODEL_FILE = "detect.tflite";
  private static final String TF_OD_API_MODEL_FILE = "dogs.tflite";
  private static final String TF_OD_API_LABELS_FILE = "labelmap.txt";
  private static final DetectorMode MODE = DetectorMode.TF_OD_API;
  // Minimum detection confidence to track a detection.
  private static final float MINIMUM_CONFIDENCE_TF_OD_API = 0.5f;
  private static final boolean MAINTAIN_ASPECT = false;
  private static final Size DESIRED_PREVIEW_SIZE = new Size(640, 480);
  private static final boolean SAVE_PREVIEW_BITMAP = false;
  private static final float TEXT_SIZE_DIP = 10;
  OverlayView trackingOverlay;
  private Integer sensorOrientation;

  private Detector detector;

  private long lastProcessingTimeMs;
  private Bitmap rgbFrameBitmap = null;
  private Bitmap croppedBitmap = null;
  private Bitmap cropCopyBitmap = null;

  private boolean computingDetection = false;

  private long timestamp = 0;

  private Matrix frameToCropTransform;
  private Matrix cropToFrameTransform;

  private MultiBoxTracker tracker;

  private BorderedText borderedText;

  @Override
  public void onPreviewSizeChosen(final Size size, final int rotation) {
    final float textSizePx =
        TypedValue.applyDimension(
            TypedValue.COMPLEX_UNIT_DIP, TEXT_SIZE_DIP, getResources().getDisplayMetrics());
    borderedText = new BorderedText(textSizePx);
    borderedText.setTypeface(Typeface.MONOSPACE);

    tracker = new MultiBoxTracker(this);

    int cropSize = TF_OD_API_INPUT_SIZE;

    try {
      detector =
          TFLiteObjectDetectionAPIModel.create(
              this,
              TF_OD_API_MODEL_FILE,
              TF_OD_API_LABELS_FILE,
              TF_OD_API_INPUT_SIZE,
              TF_OD_API_IS_QUANTIZED);
      cropSize = TF_OD_API_INPUT_SIZE;
    } catch (final IOException e) {
      e.printStackTrace();
      LOGGER.e(e, "Exception initializing Detector!");
      Toast toast =
          Toast.makeText(
              getApplicationContext(), "Detector could not be initialized", Toast.LENGTH_SHORT);
      toast.show();
      finish();
    }

    previewWidth = size.getWidth();
    previewHeight = size.getHeight();

    sensorOrientation = rotation - getScreenOrientation();
    LOGGER.i("Camera orientation relative to screen canvas: %d", sensorOrientation);

    LOGGER.i("Initializing at size %dx%d", previewWidth, previewHeight);
    rgbFrameBitmap = Bitmap.createBitmap(previewWidth, previewHeight, Config.ARGB_8888);
    croppedBitmap = Bitmap.createBitmap(cropSize, cropSize, Config.ARGB_8888);

    frameToCropTransform =
        ImageUtils.getTransformationMatrix(
            previewWidth, previewHeight,
            cropSize, cropSize,
            sensorOrientation, MAINTAIN_ASPECT);

    cropToFrameTransform = new Matrix();
    frameToCropTransform.invert(cropToFrameTransform);

    trackingOverlay = (OverlayView) findViewById(R.id.tracking_overlay);
    trackingOverlay.addCallback(
        new DrawCallback() {
          @Override
          public void drawCallback(final Canvas canvas) {
            tracker.draw(canvas);
            if (isDebug()) {
              tracker.drawDebug(canvas);
            }
          }
        });

    tracker.setFrameConfiguration(previewWidth, previewHeight, sensorOrientation);
  }

  @Override
  protected void processImage() {
    ++timestamp;
    final long currTimestamp = timestamp;
    trackingOverlay.postInvalidate();

    // No mutex needed as this method is not reentrant.
    if (computingDetection) {
      readyForNextImage();
      return;
    }
    computingDetection = true;
    LOGGER.i("Preparing image " + currTimestamp + " for detection in bg thread.");

    rgbFrameBitmap.setPixels(getRgbBytes(), 0, previewWidth, 0, 0, previewWidth, previewHeight);

    readyForNextImage();

    final Canvas canvas = new Canvas(croppedBitmap);
    canvas.drawBitmap(rgbFrameBitmap, frameToCropTransform, null);
    // For examining the actual TF input.
    if (SAVE_PREVIEW_BITMAP) {
      ImageUtils.saveBitmap(croppedBitmap);
    }

    runInBackground(
        new Runnable() {
          @Override
          public void run() {
            LOGGER.i("Running detection on image " + currTimestamp);
            final long startTime = SystemClock.uptimeMillis();
            final List<Detector.Recognition> results = detector.recognizeImage(croppedBitmap);
            lastProcessingTimeMs = SystemClock.uptimeMillis() - startTime;

            cropCopyBitmap = Bitmap.createBitmap(croppedBitmap);
            final Canvas canvas = new Canvas(cropCopyBitmap);
            final Paint paint = new Paint();
            paint.setColor(Color.RED);
            paint.setStyle(Style.STROKE);
            paint.setStrokeWidth(2.0f);

            float minimumConfidence = MINIMUM_CONFIDENCE_TF_OD_API;
            switch (MODE) {
              case TF_OD_API:
                minimumConfidence = MINIMUM_CONFIDENCE_TF_OD_API;
                break;
            }

            final List<Detector.Recognition> mappedRecognitions =
                new ArrayList<Detector.Recognition>();

            for (final Detector.Recognition result : results) {
              final RectF location = result.getLocation();
              if (location != null && result.getConfidence() >= minimumConfidence) {
                canvas.drawRect(location, paint);

                cropToFrameTransform.mapRect(location);

                result.setLocation(location);
                mappedRecognitions.add(result);
              }
            }

            tracker.trackResults(mappedRecognitions, currTimestamp);
            trackingOverlay.postInvalidate();

            computingDetection = false;

            runOnUiThread(
                new Runnable() {
                  @Override
                  public void run() {
                    showFrameInfo(previewWidth + "x" + previewHeight);
                    showCropInfo(cropCopyBitmap.getWidth() + "x" + cropCopyBitmap.getHeight());
                    showInference(lastProcessingTimeMs + "ms");
                  }
                });
          }
        });
  }

  @Override
  protected int getLayoutId() {
    return R.layout.tfe_od_camera_connection_fragment_tracking;
  }

  @Override
  protected Size getDesiredPreviewFrameSize() {
    return DESIRED_PREVIEW_SIZE;
  }

  // Which detection model to use: by default uses Tensorflow Object Detection API frozen
  // checkpoints.
  private enum DetectorMode {
    TF_OD_API;
  }

  @Override
  protected void setUseNNAPI(final boolean isChecked) {
    runInBackground(
        () -> {
          try {
            detector.setUseNNAPI(isChecked);
          } catch (UnsupportedOperationException e) {
            LOGGER.e(e, "Failed to set \"Use NNAPI\".");
            runOnUiThread(
                () -> {
                  Toast.makeText(this, e.getMessage(), Toast.LENGTH_LONG).show();
                });
          }
        });
  }

  @Override
  protected void setNumThreads(final int numThreads) {
    runInBackground(() -> detector.setNumThreads(numThreads));
  }
}

Other info / Complete Logs

The error log is:
Error getting native address of native library: task_vision_jni
java.lang.IllegalArgumentException: Error occurred when initializing ObjectDetector: Mobile SSD models are expected to have exactly 4 outputs, found 2
    at org.tensorflow.lite.task.vision.detector.ObjectDetector.initJniWithByteBuffer(Native Method)
    at org.tensorflow.lite.task.vision.detector.ObjectDetector.access$100(ObjectDetector.java:88)
    at org.tensorflow.lite.task.vision.detector.ObjectDetector$3.createHandle(ObjectDetector.java:223)
    at org.tensorflow.lite.task.core.TaskJniUtils.createHandleFromLibrary(TaskJniUtils.java:91)
    at org.tensorflow.lite.task.vision.detector.ObjectDetector.createFromBufferAndOptions(ObjectDetector.java:219)
    at org.tensorflow.lite.examples.detection.tflite.TFLiteObjectDetectionAPIModel.<init>(TFLiteObjectDetectionAPIModel.java:87)
    at org.tensorflow.lite.examples.detection.tflite.TFLiteObjectDetectionAPIModel.create(TFLiteObjectDetectionAPIModel.java:81)
    at org.tensorflow.lite.examples.detection.DetectorActivity.onPreviewSizeChosen(DetectorActivity.java:103)
    at org.tensorflow.lite.examples.detection.CameraActivity$7.onPreviewSizeChosen(CameraActivity.java:448)
    at org.tensorflow.lite.examples.detection.CameraConnectionFragment.setUpCameraOutputs(CameraConnectionFragment.java:360)
    at org.tensorflow.lite.examples.detection.CameraConnectionFragment.openCamera(CameraConnectionFragment.java:365)
    at org.tensorflow.lite.examples.detection.CameraConnectionFragment.-$$Nest$mopenCamera(Unknown Source:0)
    at org.tensorflow.lite.examples.detection.CameraConnectionFragment$3.onSurfaceTextureAvailable(CameraConnectionFragment.java:174)
    at android.view.TextureView.getTextureLayer(TextureView.java:410)
    at android.view.TextureView.draw(TextureView.java:353)
    at android.view.View.updateDisplayListIfDirty(View.java:21885)
    at android.view.View.draw(View.java:22743)
    at android.view.ViewGroup.drawChild(ViewGroup.java:4542)
    at android.view.ViewGroup.dispatchDraw(ViewGroup.java:4300)
    at android.view.View.updateDisplayListIfDirty(View.java:21876)
    at android.view.View.draw(View.java:22743)
    at android.view.ViewGroup.drawChild(ViewGroup.java:4542)
    at android.view.ViewGroup.dispatchDraw(ViewGroup.java:4300)
    at android.view.View.updateDisplayListIfDirty(View.java:21876)
    at android.view.View.draw(View.java:22743)
    at android.view.ViewGroup.drawChild(ViewGroup.java:4542)
    at android.view.ViewGroup.dispatchDraw(ViewGroup.java:4300)
    at android.view.View.draw(View.java:23021)
    at android.view.View.updateDisplayListIfDirty(View.java:21885)
    at android.view.View.draw(View.java:22743)
    at android.view.ViewGroup.drawChild(ViewGroup.java:4542)
    at androidx.coordinatorlayout.widget.CoordinatorLayout.drawChild(CoordinatorLayout.java:1246)
    at android.view.ViewGroup.dispatchDraw(ViewGroup.java:4300)
    at android.view.View.draw(View.java:23021)
    at android.view.View.updateDisplayListIfDirty(View.java:21885)
    at android.view.View.draw(View.java:22743)
    at android.view.ViewGroup.drawChild(ViewGroup.java:4542)
    at android.view.ViewGroup.dispatchDraw(ViewGroup.java:4300)
    at android.view.View.updateDisplayListIfDirty(View.java:21876)
    at android.view.View.draw(View.java:22743)
    at android.view.ViewGroup.drawChild(ViewGroup.java:4542)
    at android.view.ViewGroup.dispatchDraw(ViewGroup.java:4300)
    at android.view.View.updateDisplayListIfDirty(View.java:21876)
    at android.view.View.draw(View.java:22743)
    at android.view.ViewGroup.drawChild(ViewGroup.java:4542)
    at android.view.ViewGroup.dispatchDraw(ViewGroup.java:4300)
    at android.view.View.updateDisplayListIfDirty(View.java:21876)
    at android.view.View.draw(View.java:22743)
    at android.view.ViewGroup.drawChild(ViewGroup.java:4542)
    at android.view.ViewGroup.dispatchDraw(ViewGroup.java:4300)
    at android.view.View.updateDisplayListIfDirty(View.java:21876)
    at android.view.View.draw(View.java:22743)
    at android.view.ViewGroup.drawChild(ViewGroup.java:4542)
    at android.view.ViewGroup.dispatchDraw(ViewGroup.java:4300)
2024-07-31 00:44:09.336 31317-31317 TaskJniUtils org...lite.examples.objectdetection E
    at android.view.View.draw(View.java:23021)
    at com.android.internal.policy.DecorView.draw(DecorView.java:891)
    at android.view.View.updateDisplayListIfDirty(View.java:21885)
    at android.view.ThreadedRenderer.updateViewTreeDisplayList(ThreadedRenderer.java:534)
    at android.view.ThreadedRenderer.updateRootDisplayList(ThreadedRenderer.java:542)
    at android.view.ThreadedRenderer.draw(ThreadedRenderer.java:625)
    at android.view.ViewRootImpl.draw(ViewRootImpl.java:4657)
    at android.view.ViewRootImpl.performDraw(ViewRootImpl.java:4375)
    at android.view.ViewRootImpl.performTraversals(ViewRootImpl.java:3486)
    at android.view.ViewRootImpl.doTraversal(ViewRootImpl.java:2277)
    at android.view.ViewRootImpl$TraversalRunnable.run(ViewRootImpl.java:9037)
    at android.view.Choreographer$CallbackRecord.run(Choreographer.java:1142)
    at android.view.Choreographer.doCallbacks(Choreographer.java:946)
    at android.view.Choreographer.doFrame(Choreographer.java:875)
    at android.view.Choreographer$FrameDisplayEventReceiver.run(Choreographer.java:1127)
    at android.os.Handler.handleCallback(Handler.java:938)
    at android.os.Handler.dispatchMessage(Handler.java:99)
    at android.os.Looper.loopOnce(Looper.java:210)
    at android.os.Looper.loop(Looper.java:299)
    at android.app.ActivityThread.main(ActivityThread.java:8293)
    at java.lang.reflect.Method.invoke(Native Method)
    at com.android.internal.os.RuntimeInit$MethodAndArgsCaller.run(RuntimeInit.java:556)
    at com.android.internal.os.ZygoteInit.main(ZygoteInit.java:1045)
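
To cross-check the "found 2" part of the error outside Android, the exported model's output tensors could be listed on the desktop with the TensorFlow Lite interpreter. This is only a diagnostic sketch, not part of the original setup. The helper names (`summarize_outputs`, `has_expected_output_count`, `inspect_model`) are hypothetical, the expected count of 4 is taken from the error message, and the path `exported_model/dogs.tflite` is an assumption based on the `export_dir` used in the training script.

```python
from typing import Any, Dict, List, Sequence

# Output count the Android detector asks for, taken from the error message.
EXPECTED_SSD_OUTPUTS = 4


def summarize_outputs(output_details: Sequence[Dict[str, Any]]) -> List[str]:
    """Format one line per entry returned by Interpreter.get_output_details()."""
    lines = []
    for i, detail in enumerate(output_details):
        shape = [int(d) for d in detail["shape"]]
        lines.append(f"output {i}: name={detail['name']} shape={shape}")
    return lines


def has_expected_output_count(
    output_details: Sequence[Dict[str, Any]],
    expected: int = EXPECTED_SSD_OUTPUTS,
) -> bool:
    """Return True when the model exposes the expected number of outputs."""
    return len(output_details) == expected


def inspect_model(model_path: str) -> bool:
    """Load a .tflite file, print its output tensors, and check the count."""
    # Imported lazily so the helpers above can be used without TensorFlow.
    import tensorflow as tf

    interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()
    details = interpreter.get_output_details()
    for line in summarize_outputs(details):
        print(line)
    return has_expected_output_count(details)
```

Calling `inspect_model("exported_model/dogs.tflite")` (assumed path) prints one line per output tensor and returns `False` if the count differs from 4, which would match what the Android log reports.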
libofei2004 added the type:modelmaker label Jul 30, 2024
google-ml-butler bot added the type:support label Jul 30, 2024
kuaashish added the task:object detection, os:linux-non-arm, platform:python, and platform:android labels Jul 31, 2024