Skip to content

Commit

Permalink
Implement non-maximum suppression for multi-pose estimation (#36)
Browse files Browse the repository at this point in the history
  • Loading branch information
tucan9389 committed Apr 18, 2020
1 parent 54b1997 commit 542f78c
Show file tree
Hide file tree
Showing 11 changed files with 568 additions and 53 deletions.
14 changes: 13 additions & 1 deletion PoseEstimation-TFLiteSwift.xcodeproj/project.pbxproj
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
712A7FD12426691B00B043F9 /* PEFMCPMPoseEstimator.swift in Sources */ = {isa = PBXBuildFile; fileRef = 712A7FD02426691B00B043F9 /* PEFMCPMPoseEstimator.swift */; };
712A7FD324266EC700B043F9 /* pefm_hourglass_v1.tflite in Resources */ = {isa = PBXBuildFile; fileRef = 712A7FD224266EC700B043F9 /* pefm_hourglass_v1.tflite */; };
7138DCCF242142FE0048E1D2 /* TFLiteFlatArray.swift in Sources */ = {isa = PBXBuildFile; fileRef = 7138DCCE242142FE0048E1D2 /* TFLiteFlatArray.swift */; };
71DD577F2446D7CF0024C146 /* NonMaximumnonSuppression.swift in Sources */ = {isa = PBXBuildFile; fileRef = 71DD577E2446D7CF0024C146 /* NonMaximumnonSuppression.swift */; };
71E8D9172438BAC10081DD6E /* openpose_ildoonet.tflite in Resources */ = {isa = PBXBuildFile; fileRef = 71E8D9162438BAC10081DD6E /* openpose_ildoonet.tflite */; };
71E8D9192438BAD80081DD6E /* OpenPosePoseEstimator.swift in Sources */ = {isa = PBXBuildFile; fileRef = 71E8D9182438BAD80081DD6E /* OpenPosePoseEstimator.swift */; };
71E8D93B243CC5330081DD6E /* README.md in Resources */ = {isa = PBXBuildFile; fileRef = 71E8D93A243CC5320081DD6E /* README.md */; };
Expand Down Expand Up @@ -60,6 +61,7 @@
712A7FD02426691B00B043F9 /* PEFMCPMPoseEstimator.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = PEFMCPMPoseEstimator.swift; sourceTree = "<group>"; };
712A7FD224266EC700B043F9 /* pefm_hourglass_v1.tflite */ = {isa = PBXFileReference; lastKnownFileType = file; path = pefm_hourglass_v1.tflite; sourceTree = "<group>"; };
7138DCCE242142FE0048E1D2 /* TFLiteFlatArray.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = TFLiteFlatArray.swift; sourceTree = "<group>"; };
71DD577E2446D7CF0024C146 /* NonMaximumnonSuppression.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = NonMaximumnonSuppression.swift; sourceTree = "<group>"; };
71E8D9162438BAC10081DD6E /* openpose_ildoonet.tflite */ = {isa = PBXFileReference; lastKnownFileType = file; path = openpose_ildoonet.tflite; sourceTree = "<group>"; };
71E8D9182438BAD80081DD6E /* OpenPosePoseEstimator.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = OpenPosePoseEstimator.swift; sourceTree = "<group>"; };
71E8D93A243CC5320081DD6E /* README.md */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = net.daringfireball.markdown; path = README.md; sourceTree = "<group>"; };
Expand Down Expand Up @@ -104,8 +106,8 @@
7105C91B241CE9B6001A4325 /* Main.storyboard */,
712A7FC02424BDD800B043F9 /* StillImageViewController.swift */,
7105C919241CE9B6001A4325 /* LiveImageViewController.swift */,
712A7FC3242504EB00B043F9 /* PoseKeypointsDrawingView.swift */,
7105C92B241D0150001A4325 /* MLModel */,
71DD577D2446D7A40024C146 /* Algorithm */,
7105C92A241D0144001A4325 /* Extension */,
7105C929241D011F001A4325 /* View */,
7105C938241D29C5001A4325 /* Video */,
Expand All @@ -119,6 +121,7 @@
7105C929241D011F001A4325 /* View */ = {
isa = PBXGroup;
children = (
712A7FC3242504EB00B043F9 /* PoseKeypointsDrawingView.swift */,
);
name = View;
sourceTree = "<group>";
Expand Down Expand Up @@ -175,6 +178,14 @@
path = PoseEstimationForMobile;
sourceTree = "<group>";
};
71DD577D2446D7A40024C146 /* Algorithm */ = {
isa = PBXGroup;
children = (
71DD577E2446D7CF0024C146 /* NonMaximumnonSuppression.swift */,
);
name = Algorithm;
sourceTree = "<group>";
};
71E8D9152438BA5B0081DD6E /* OpenPose */ = {
isa = PBXGroup;
children = (
Expand Down Expand Up @@ -312,6 +323,7 @@
7105C92F241D0235001A4325 /* PoseEstimator.swift in Sources */,
712A7FD12426691B00B043F9 /* PEFMCPMPoseEstimator.swift in Sources */,
71E8D9192438BAD80081DD6E /* OpenPosePoseEstimator.swift in Sources */,
71DD577F2446D7CF0024C146 /* NonMaximumnonSuppression.swift in Sources */,
712A7FC4242504EB00B043F9 /* PoseKeypointsDrawingView.swift in Sources */,
7138DCCF242142FE0048E1D2 /* TFLiteFlatArray.swift in Sources */,
7105C93E241E90C2001A4325 /* DataExtension.swift in Sources */,
Expand Down
221 changes: 209 additions & 12 deletions PoseEstimation-TFLiteSwift/Base.lproj/Main.storyboard

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion PoseEstimation-TFLiteSwift/LiveImageViewController.swift
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ extension LiveImageViewController {
let scalingRatio = pixelBuffer.size.width / overlayViewRelativeRect.width
let targetAreaRect = overlayViewRelativeRect.scaled(to: scalingRatio)
let input: PoseEstimationInput = .pixelBuffer(pixelBuffer: pixelBuffer, cropArea: .customAspectFill(rect: targetAreaRect))
let result: Result<PoseEstimationOutput, PoseEstimationError> = poseEstimator.inference(with: input)
let result: Result<PoseEstimationOutput, PoseEstimationError> = poseEstimator.inference(input, with: nil, on: nil)

switch (result) {
case .success(let output):
Expand Down
79 changes: 79 additions & 0 deletions PoseEstimation-TFLiteSwift/NonMaximumnonSuppression.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
//
// NonMaximumnonSuppression.swift
// PoseEstimation-TFLiteSwift
//
// Created by Doyoung Gwak on 2020/04/15.
// Copyright © 2020 Doyoung Gwak. All rights reserved.
//

import Foundation

class NonMaximumnonSuppression {
    /// A candidate local maximum on the heatmap: grid row/col and its confidence value.
    typealias MaximumPoint = (row: Int, col: Int, val: Float32)

    /// Finds local maxima of one body-part channel of the model's heatmap in a
    /// single row-major streaming pass, suppressing any cell that has a larger
    /// value within `filterSize` columns of a stored candidate.
    ///
    /// `lastMaximumColumns[c]` holds, per column, the best candidate seen so far
    /// that has not yet been emitted; a candidate is flushed to `results` once the
    /// scan has moved more than `filterSize` rows past it.
    ///
    /// - Parameters:
    ///   - heatmap: flat model output, indexed as `[0, row, col, partIndex]`
    ///   - partIndex: heatmap channel (body part) to scan
    ///   - width: heatmap width in columns
    ///   - height: heatmap height in rows
    /// - Returns: surviving local-maximum points for this part
    static func process(_ heatmap: TFLiteFlatArray<Float32>, partIndex: Int, width: Int, height: Int) -> [MaximumPoint] {
        // Suppression radius in grid cells. NOTE(review): hard-coded here even
        // though the `keypoints(partIndex:filterSize:threshold:)` wrapper takes a
        // `filterSize` argument — TODO confirm they should be unified.
        let filterSize = 3
        var lastMaximumColumns: [MaximumPoint?] = (0..<width).map { _ in nil }
        var results: [MaximumPoint] = []

        for row in (0..<height) {
            for col in (0..<width) {
                var smallerColumns: [Int] = []
                var hasBiggerValueInFilterSize = false
                // Compare the current cell against stored candidates within the window.
                for targetColumn in max(col-filterSize, 0)...min(col+filterSize, width-1) {
                    if let lastMaximumPoint = lastMaximumColumns[targetColumn] {
                        if lastMaximumPoint.val < heatmap[heatmap: 0, row, col, partIndex] {
                            // Stored candidate is smaller — remember it for removal.
                            smallerColumns.append(targetColumn)
                        } else if lastMaximumPoint.val > heatmap[heatmap: 0, row, col, partIndex] {
                            // A bigger value already lives nearby — discard the current cell.
                            hasBiggerValueInFilterSize = true
                            break
                        }
                    }
                }
                if !hasBiggerValueInFilterSize {
                    // Current cell dominates its window: evict the smaller candidates
                    // and store it as the new candidate for this column.
                    for smallerColumn in smallerColumns {
                        lastMaximumColumns[smallerColumn] = nil
                    }
                    // NOTE(review): labels here are (col:, row:) while MaximumPoint is
                    // declared (row:, col:) — this relies on tuple label reordering in
                    // assignment; TODO confirm it stores into the intended fields.
                    lastMaximumColumns[col] = (col: col, row: row, val: heatmap[heatmap: 0, row, col, partIndex])
                }
                // Flush: once the scan is more than filterSize rows below a stored
                // candidate, nothing can suppress it anymore — emit it and clear any
                // duplicate references within the doubled window.
                if let lastMaximumPoint = lastMaximumColumns[col] {
                    if lastMaximumPoint.row < row-filterSize {
                        for targetColumn in col...min(col+filterSize*2, width-1) {
                            if let compareLastMaximumPoint = lastMaximumColumns[targetColumn],
                                lastMaximumPoint.row == compareLastMaximumPoint.row,
                                lastMaximumPoint.col == compareLastMaximumPoint.col {
                                lastMaximumColumns[targetColumn] = nil
                            }
                        }
                        results.append((row: lastMaximumPoint.row,
                                        col: lastMaximumPoint.col,
                                        val: lastMaximumPoint.val))
                    }
                }
            }
        }

        // Flush candidates that survived to the end of the scan.
        for (offset, lastMaximumPoint) in lastMaximumColumns.enumerated() {
            guard let lastMaximumPoint = lastMaximumPoint else { continue }
            for targetColumn in offset...min(offset+filterSize*2, width-1) {
                if let compareLastMaximumPoint = lastMaximumColumns[targetColumn],
                    lastMaximumPoint.row == compareLastMaximumPoint.row,
                    lastMaximumPoint.col == compareLastMaximumPoint.col {
                    lastMaximumColumns[targetColumn] = nil
                }
            }
            results.append((row: lastMaximumPoint.row,
                            col: lastMaximumPoint.col,
                            val: lastMaximumPoint.val))
        }

        return results
    }


}
93 changes: 81 additions & 12 deletions PoseEstimation-TFLiteSwift/OpenPose/OpenPosePoseEstimator.swift
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,41 @@ class OpenPosePoseEstimator: PoseEstimator {
return imageInterpreter
}()

func inference(with input: PoseEstimationInput) -> OpenPoseResult {
var modelOutput: [TFLiteFlatArray<Float32>]?

func inference(_ input: PoseEstimationInput, with threshold: Float?, on partIndex: Int?) -> OpenPoseResult {

// initialize
modelOutput = nil

// preprocss
guard let inputData = imageInterpreter.preprocess(with: input)
else { return .failure(.failToCreateInputData) }

// inference
guard let outputs = imageInterpreter.inference(with: inputData)
modelOutput = imageInterpreter.inference(with: inputData)
guard let outputs = modelOutput
else { return .failure(.failToInference) }

// postprocess
let result = postprocess(with: outputs)
let result = OpenPoseResult.success(postprocess(outputs, with: threshold, on: partIndex))

return result
}

private func postprocess(with outputs: [TFLiteFlatArray<Float32>]) -> OpenPoseResult {
return .success(PoseEstimationOutput(outputs: outputs))
private func postprocess(_ outputs: [TFLiteFlatArray<Float32>], with threshold: Float?=nil, on partIndex: Int?=nil) -> PoseEstimationOutput {
// if you want to postprocess with only single person, use .singlePerson on humanType
// in .multiPerson, if the bodyPart is nil, parse all part
return PoseEstimationOutput(outputs: outputs, humanType: .multiPerson(threshold: threshold, bodyPart: partIndex))
}

func postprocessOnLastOutput(with threshold: Float?=nil, on partIndex: Int?=nil) -> PoseEstimationOutput? {
guard let outputs = modelOutput else { return nil }
return postprocess(outputs, with: threshold, on: partIndex)
}

var partNames: [String] {
return Output.BodyPart.allCases.map { $0.rawValue }
}
}

Expand All @@ -60,7 +80,7 @@ private extension OpenPosePoseEstimator {
static let count = BodyPart.allCases.count * 2 // 38
}
enum BodyPart: String, CaseIterable {
case NOSE = "nose" // 0
case NOSE = "Nose" // 0
case NECK = "Neck" // 1
case RIGHT_SHOULDER = "RShoulder" // 2
case RIGHT_ELBOW = "RElbow" // 3
Expand Down Expand Up @@ -109,15 +129,24 @@ private extension OpenPosePoseEstimator {
}

private extension PoseEstimationOutput {
init(outputs: [TFLiteFlatArray<Float32>]) {
let keypoints = convertToKeypoints(from: outputs)
let lines = makeLines(with: keypoints)

humans = [Human(keypoints: keypoints, lines: lines)]
    /// Selects the postprocessing path for a model output.
    enum HumanType {
        case singlePerson
        // threshold: minimum keypoint confidence (nil keeps everything);
        // bodyPart: index of the single part channel to parse (nil = all parts).
        case multiPerson(threshold: Float?, bodyPart: Int?)
    }

    /// Builds the output from raw model tensors.
    /// Defaults to the original single-person path to keep existing callers working.
    init(outputs: [TFLiteFlatArray<Float32>], humanType: HumanType = .singlePerson) {
        switch humanType {
        case .singlePerson:
            // One human: take the argmax keypoint of each part channel and connect lines.
            let keypoints = convertToKeypoints(from: outputs)
            let lines = makeLines(with: keypoints)
            humans = [Human(keypoints: keypoints, lines: lines)]
        case .multiPerson(threshold: let threshold, let bodyPart):
            // Potentially many humans via non-maximum suppression on one part channel.
            humans = parseMultiHuman(from: outputs, on: bodyPart, with: threshold)
        }
    }

func convertToKeypoints(from outputs: [TFLiteFlatArray<Float32>]) -> [Keypoint] {
let output = outputs[0]
let output = outputs[0] // openpose_ildoonet.tflite only use the first output

// get (col, row)s from heatmaps
let keypointIndexInfos: [(row: Int, col: Int, val: Float32)] = (0..<OpenPosePoseEstimator.Output.ConfidenceMap.count).map { heatmapIndex in
Expand Down Expand Up @@ -156,6 +185,32 @@ private extension PoseEstimationOutput {
return (from: fromKeypoint, to: toKeypoint)
}
}

    /// Parses multi-person candidates for a single body-part channel.
    ///
    /// Runs non-maximum suppression on the `partIndex` channel and wraps each
    /// surviving peak in its own `Human`, with every other part slot filled by a
    /// sentinel keypoint (position .zero, score -100).
    ///
    /// - Parameters:
    ///   - outputs: raw model tensors; only the first is used
    ///   - partIndex: part channel to parse; nil means "all parts", which is not
    ///     implemented yet and yields no humans
    ///   - threshold: minimum confidence for a peak to survive (nil keeps all)
    /// - Returns: one single-keypoint `Human` per surviving peak
    func parseMultiHuman(from outputs: [TFLiteFlatArray<Float32>], on partIndex: Int?, with threshold: Float?) -> [Human] {
        // nil part index = parse all parts; not supported yet, so return no humans.
        guard let partIndex = partIndex else { return [] }

        let output = outputs[0] // openpose_ildoonet.tflite only use the first output

        // 1. nms: peak (row, col, score) candidates for this part channel
        let keypointIndexes = output.keypoints(partIndex: partIndex, filterSize: 3, threshold: threshold)
        let kps: [Keypoint] = keypointIndexes.map { keypointInfo in
            // +0.5 centers the point in its heatmap cell; divide to normalize to 0...1.
            let x = (CGFloat(keypointInfo.col) + 0.5) / CGFloat(OpenPosePoseEstimator.Output.ConfidenceMap.width)
            let y = (CGFloat(keypointInfo.row) + 0.5) / CGFloat(OpenPosePoseEstimator.Output.ConfidenceMap.height)
            let score = Float(keypointInfo.val)
            return Keypoint(position: CGPoint(x: x, y: y), score: score)
        }

        // Each peak becomes its own Human; all other part slots get a sentinel
        // keypoint (.zero position, score -100) so the keypoint array stays full-size.
        return kps.map { keypoint in
            let keypoints: [Keypoint] = OpenPosePoseEstimator.Output.BodyPart.allCases.enumerated().map { offset, _ in
                return (offset == partIndex) ? keypoint : Keypoint(position: .zero, score: -100)
            }
            return Human(keypoints: keypoints, lines: [])
        }

        // 2. TODO: part association (grouping keypoints of all parts into humans) is not implemented yet.

        // return []
    }
}

extension TFLiteFlatArray where Element == Float32 {
Expand All @@ -173,3 +228,17 @@ extension TFLiteFlatArray where Element == Float32 {
}
}
}

// MARK: - NMS
extension TFLiteFlatArray where Element == Float32 {
    /// Runs non-maximum suppression over one part channel of this heatmap and
    /// optionally filters the resulting peaks by confidence.
    ///
    /// - Parameters:
    ///   - partIndex: heatmap channel (body part) to scan
    ///   - filterSize: NOTE(review): currently unused — `NonMaximumnonSuppression.process`
    ///     hard-codes its own radius of 3; TODO confirm and wire this through
    ///   - threshold: keep only peaks with val strictly above this (nil keeps all)
    /// - Returns: surviving (row, col, value) peaks
    func keypoints(partIndex: Int, filterSize: Int, threshold: Float?) -> [(row: Int, col: Int, val: Element)] {
        let hWidth = OpenPosePoseEstimator.Output.ConfidenceMap.width
        let hHeight = OpenPosePoseEstimator.Output.ConfidenceMap.height
        let results = NonMaximumnonSuppression.process(self, partIndex: partIndex, width: hWidth, height: hHeight)
        if let threshold = threshold {
            return results.filter { $0.val > threshold }
        } else {
            return results
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,38 @@ class PEFMCPMPoseEstimator: PoseEstimator {
return imageInterpreter
}()

func inference(with input: PoseEstimationInput) -> PoseNetResult {
var modelOutput: [TFLiteFlatArray<Float32>]?

func inference(_ input: PoseEstimationInput, with threshold: Float?, on partIndex: Int?) -> PoseNetResult {

// initialize
modelOutput = nil

// preprocss
guard let inputData = imageInterpreter.preprocess(with: input)
else { return .failure(.failToCreateInputData) }

// inference
guard let outputs = imageInterpreter.inference(with: inputData)
else { return .failure(.failToInference) }

// postprocess
let result = postprocess(with: outputs)
let result = PoseNetResult.success(postprocess(with: outputs))

return result
}

private func postprocess(with outputs: [TFLiteFlatArray<Float32>]) -> PoseEstimationOutput {
return PoseEstimationOutput(outputs: outputs)
}

func postprocessOnLastOutput(with threshold: Float?=nil, on partIndex: Int?=nil) -> PoseEstimationOutput? {
guard let outputs = modelOutput else { return nil }
return postprocess(with: outputs)
}

private func postprocess(with outputs: [TFLiteFlatArray<Float32>]) -> PoseNetResult {
return .success(PoseEstimationOutput(outputs: outputs))
var partNames: [String] {
return Output.BodyPart.allCases.map { $0.rawValue }
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,38 @@ class PEFMHourglassPoseEstimator: PoseEstimator {
return imageInterpreter
}()

func inference(with input: PoseEstimationInput) -> PoseNetResult {
var modelOutput: [TFLiteFlatArray<Float32>]?

func inference(_ input: PoseEstimationInput, with threshold: Float?, on partIndex: Int?) -> PoseNetResult {

// initialize
modelOutput = nil

// preprocss
guard let inputData = imageInterpreter.preprocess(with: input)
else { return .failure(.failToCreateInputData) }

// inference
guard let outputs = imageInterpreter.inference(with: inputData)
else { return .failure(.failToInference) }

// postprocess
let result = postprocess(with: outputs)
let result = PoseNetResult.success(postprocess(with: outputs))

return result
}

private func postprocess(with outputs: [TFLiteFlatArray<Float32>]) -> PoseNetResult {
return .success(PoseEstimationOutput(outputs: outputs))
private func postprocess(with outputs: [TFLiteFlatArray<Float32>]) -> PoseEstimationOutput {
return PoseEstimationOutput(outputs: outputs)
}

func postprocessOnLastOutput(with threshold: Float?=nil, on partIndex: Int?=nil) -> PoseEstimationOutput? {
guard let outputs = modelOutput else { return nil }
return postprocess(with: outputs)
}

var partNames: [String] {
return Output.BodyPart.allCases.map { $0.rawValue }
}
}

Expand Down
4 changes: 3 additions & 1 deletion PoseEstimation-TFLiteSwift/PoseEstimator.swift
Original file line number Diff line number Diff line change
Expand Up @@ -99,5 +99,7 @@ enum PoseEstimationError: Error {
}

protocol PoseEstimator {
func inference(with input: PoseEstimationInput) -> Result<PoseEstimationOutput, PoseEstimationError>
func inference(_ input: PoseEstimationInput, with threshold: Float?, on partIndex: Int?) -> Result<PoseEstimationOutput, PoseEstimationError>
func postprocessOnLastOutput(with threshold: Float?, on partIndex: Int?) -> PoseEstimationOutput?
var partNames: [String] { get }
}
2 changes: 1 addition & 1 deletion PoseEstimation-TFLiteSwift/PoseKeypointsDrawingView.swift
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ private extension PoseKeypointsDrawingView {
static let radius: CGFloat = 5
static let borderWidth: CGFloat = 2
static let borderColor: UIColor = UIColor.red
static let fillColor: UIColor = UIColor(red: 0.6, green: 0.2, blue: 0.2, alpha: 1)
static let fillColor: UIColor = UIColor.green
}
enum Line {
static let width: CGFloat = 2
Expand Down
Loading

0 comments on commit 542f78c

Please sign in to comment.