Skip to content

Virtual Background with WebRTC in Android

Kai Dao edited this page Mar 11, 2024 · 4 revisions

Overview

Virtual backgrounds have become an essential feature in the video conferencing world. They let us replace our real background with an image or a video, and we can also upload custom images to use as the background.

👉 By the end of this wiki, you can expect the virtual background feature to look like this:

Virtual Background on Android (Mediapipe for Image segment)

Dependencies

Add the dependencies for the Mediapipe Android libraries to the module's app-level gradle file, which is usually app/build.gradle:

implementation 'com.google.mediapipe:tasks-vision:0.10.11'

Common WebRTC terms you should know

  1. VideoFrame: It contains the buffer of the frame captured by the camera device in I420 format.
  2. VideoSink: It is used to send the frame back to the WebRTC native source.
  3. VideoSource: It reads the camera device, produces VideoFrames, and delivers them to VideoSinks.
  4. VideoProcessor: It is an interface provided by WebRTC to update the VideoFrames produced by a VideoSource.
  5. MediaStream: It is an API related to WebRTC which provides support for streaming audio and video data. It consists of zero or more MediaStreamTrack objects, representing various audio or video tracks.

Implement in code

Getting the VideoFrame from WebRTC

videoSource?.setVideoProcessor(object : VideoProcessor {
            override fun onCapturerStarted(success: Boolean) {
                // Capture-start notification; nothing to do here.
            }

            override fun onCapturerStopped() {
                // Capture-stop notification; nothing to do here.
            }

            /**
             * Receives a captured [frame], throttles processing to one frame per
             * [targetFrameInterval], and then either hands the frame to the segmenter
             * (a background image is set) or emits it as-is (no background image).
             */
            @SuppressLint("LongLogTag")
            @RequiresApi(Build.VERSION_CODES.N)
            override fun onFrameCaptured(frame: VideoFrame) {
                // Without a sink there is nowhere to deliver the result.
                if (sink == null) return

                // Throttle: skip frames that arrive before the target interval has elapsed.
                val now = System.currentTimeMillis()
                val isDue = now - lastProcessedFrameTime >= targetFrameInterval
                if (!isDue) return
                lastProcessedFrameTime = now

                val inputFrameBitmap: Bitmap? = videoFrameToBitmap(frame)
                if (inputFrameBitmap == null) {
                    Log.d(tag, "Convert video frame to bitmap failure")
                    return
                }

                runBlocking {
                    if (backgroundBitmap == null) {
                        // No virtual background selected: emit the camera bitmap directly.
                        // The frame is cached only for the duration of the emit call.
                        val passthrough = CacheFrame(originalBitmap = inputFrameBitmap, originalFrame = frame)
                        bitmapMap[lastProcessedFrameTime] = passthrough
                        emitBitmapOnFrame(inputFrameBitmap, passthrough)
                        bitmapMap.remove(lastProcessedFrameTime)
                    } else {
                        // Virtual background selected: resize the bitmap, cache it under the
                        // timestamp passed to the segmenter, and request segmentation.
                        val segmenterInput = virtualBackground.resizeBitmapKeepAspectRatio(inputFrameBitmap, 512)
                        val frameTimeMs: Long = SystemClock.uptimeMillis()
                        bitmapMap[frameTimeMs] = CacheFrame(originalBitmap = segmenterInput, originalFrame = frame)
                        imageSegmentationHelper?.segmentLiveStreamFrame(segmenterInput, frameTimeMs)
                    }
                }
            }

            override fun setSink(videoSink: VideoSink?) {
                // Keep the sink so processed frames can be sent back to WebRTC
                // once segmentation has finished.
                sink = videoSink
            }
        })

Initialize Mediapipe Image Segmenter

val segmenterListener = object : ImageSegmenterHelper.SegmenterListener {
            override fun onResults(resultBundle: ImageSegmenterHelper.ResultBundle) {
                // Process the results after Mediapipe separates the background
            }

            override fun onError(error: String, errorCode: Int) {
                // Errors are intentionally ignored here (no-op).
            }
        }

// Create the segmenter helper in live-stream mode; results are delivered to the listener above.
this.imageSegmentationHelper = ImageSegmenterHelper(
            context = context,
            runningMode = RunningMode.LIVE_STREAM,
            imageSegmenterListener = segmenterListener,
        )

Handle Person Mask from Mediapipe

/**
 * Handles a segmentation result: looks up the frame cached under the result's timestamp,
 * builds the segmented foreground from the confidence mask, composites it over the
 * current background image, and emits the composed bitmap.
 *
 * Results whose timestamp has no cached frame are ignored. The cache entry is always
 * removed once a result has been handled, including when no background image is set or
 * when processing throws, so cached bitmaps and frames cannot accumulate in the map.
 */
override fun onResults(resultBundle: ImageSegmenterHelper.ResultBundle) {
                    val timestampMS = resultBundle.frameTime
                    val cacheFrame: CacheFrame = bitmapMap[timestampMS] ?: return

                    try {
                        // Read the background once so the null check and the draw call use the
                        // same value, and bail out before doing any mask work if it is missing.
                        val background = backgroundBitmap ?: return

                        val bitmap = cacheFrame.originalBitmap
                        val mask = virtualBackground.convertFloatBufferToByteBuffer(resultBundle.results)

                        // Convert the mask buffer to an array of colors
                        val colors = virtualBackground.maskColorsFromByteBuffer(
                            mask,
                            resultBundle.width,
                            resultBundle.height,
                            bitmap,
                            expectConfidence,
                        )

                        // Create the segmented bitmap from the color array
                        val segmentedBitmap = virtualBackground.createBitmapFromColors(colors, bitmap.width, bitmap.height)

                        // Draw the segmented bitmap on top of the background
                        val outputBitmap = virtualBackground.drawSegmentedBackground(
                            segmentedBitmap,
                            background,
                            cacheFrame.originalFrame.rotation,
                        )

                        if (outputBitmap != null) {
                            emitBitmapOnFrame(outputBitmap, cacheFrame)
                        }
                    } finally {
                        // Runs on every exit path (early return, success, exception).
                        bitmapMap.remove(timestampMS)
                    }
                }

Draw segmented and background on canvas

/**
     * Composites [segmentedBitmap] over [backgroundBitmap] on a new ARGB_8888 bitmap that has
     * the same size as [segmentedBitmap].
     *
     * The background is rotated by `rotationAngle - 180` degrees (a null angle is treated as 0)
     * before it is drawn. When [rotationAngle] is 0 or 180 the rotated background is drawn into a
     * centered rectangle computed by fitting the un-rotated background inside the output; for any
     * other value it is first scaled so its width equals the shorter side of the output and then
     * drawn at the top-left corner. The segmented bitmap is drawn last at (0, 0), unrotated.
     *
     * @param segmentedBitmap The foreground bitmap drawn on top; also defines the output size.
     * @param backgroundBitmap The image drawn underneath the foreground.
     * @param rotationAngle The rotation in degrees used to orient the background; may be null.
     * @return The composed bitmap, or null if either input bitmap is null.
     */
    fun drawSegmentedBackground(
        segmentedBitmap: Bitmap?,
        backgroundBitmap: Bitmap?,
        rotationAngle: Int?
    ): Bitmap? {
        val foreground = segmentedBitmap ?: return null
        val background = backgroundBitmap ?: return null

        val composed = Bitmap.createBitmap(foreground.width, foreground.height, Bitmap.Config.ARGB_8888)
        val canvas = Canvas(composed)
        val paint = Paint(Paint.ANTI_ALIAS_FLAG)

        val rotation = Matrix().apply {
            postRotate((rotationAngle?.toFloat() ?: 0f) - 180)
        }

        // Returns the whole of `source` transformed by the rotation matrix, with filtering enabled.
        fun rotated(source: Bitmap): Bitmap =
            Bitmap.createBitmap(source, 0, 0, source.width, source.height, rotation, true)

        when (rotationAngle) {
            0, 180 -> {
                // Fit the background inside the output while keeping its aspect ratio.
                val fitScale = minOf(
                    foreground.width.toFloat() / background.width,
                    foreground.height.toFloat() / background.height
                )
                val fittedWidth = (background.width * fitScale).toInt()
                val fittedHeight = (background.height * fitScale).toInt()

                val rotatedBackground = rotated(background)

                // Center the fitted rectangle inside the output.
                val destination = Rect(
                    (foreground.width - fittedWidth) / 2,
                    (foreground.height - fittedHeight) / 2,
                    (foreground.width + fittedWidth) / 2,
                    (foreground.height + fittedHeight) / 2
                )
                canvas.drawBitmap(rotatedBackground, null, destination, paint)
            }
            else -> {
                // Scale the background so its width matches the shorter side of the output.
                val targetWidth = minOf(foreground.width.toFloat(), foreground.height.toFloat()).toInt()
                val ratio = targetWidth.toFloat() / background.width.toFloat()
                val targetHeight = (background.height * ratio).toInt()

                val scaledBackground = scaleBitmap(background, targetWidth, targetHeight)
                canvas.drawBitmap(rotated(scaledBackground), 0f, 0f, paint)
            }
        }

        // The foreground always goes on top of whatever background was drawn.
        canvas.drawBitmap(foreground, 0f, 0f, paint)

        return composed
    }

Task benchmarks

Here are the task benchmarks for the whole pipeline based on the above pre-trained models. The latency result is the average latency on Pixel 6 using CPU / GPU.

Model Name CPU Latency GPU Latency
SelfieSegmenter (square) 33.46ms 35.15ms

Reference