Skip to content

Commit

Permalink
add detect image
Browse files Browse the repository at this point in the history
  • Loading branch information
webees committed Sep 16, 2023
1 parent 5c91f20 commit 104988a
Show file tree
Hide file tree
Showing 9 changed files with 353 additions and 11 deletions.
2 changes: 2 additions & 0 deletions components.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@ declare module 'vue' {
RouterLink: typeof import('vue-router')['RouterLink']
RouterView: typeof import('vue-router')['RouterView']
TabBar: typeof import('./src/components/TabBar.vue')['default']
VanButton: typeof import('vant/es')['Button']
VanTabbar: typeof import('vant/es')['Tabbar']
VanTabbarItem: typeof import('vant/es')['TabbarItem']
VanToast: typeof import('vant/es')['Toast']
VanUploader: typeof import('vant/es')['Uploader']
}
}
7 changes: 4 additions & 3 deletions src/i18n/zhCN.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
export default {
zhCN: {
_tabbar: {
image: '图像',
image: '图片',
video: '视频',
webcam: '摄像'
},
image: '图像',
image: '图片',
video: '视频',
webcam: '摄像'
webcam: '摄像',
'Open Image': '打开图片'
}
}
6 changes: 3 additions & 3 deletions src/stores/yolo.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ import { defineStore } from 'pinia'
import type { GraphModel, io } from '@tensorflow/tfjs'

export default defineStore('yolo', () => {
const name = ref('yolov8n')
const net = ref<GraphModel<string | io.IOHandler> | null>()
const version = ref('yolov8n')
const model = ref<GraphModel<string | io.IOHandler>>()
const inputShape = ref([1, 0, 0, 3])
const loading = ref(0)

return { name, net, inputShape, loading }
return { version, model, inputShape, loading }
})
19 changes: 18 additions & 1 deletion src/styles/theme.less
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,32 @@ body {
width: 50%;
}

.w-full {
width: 100%;
}

/* width */
.h-full {
height: 100%;
}

/* display */
.flex {
display: flex
}

.space-around {
.justify-center {
justify-content: center;
}

.justify-around {
justify-content: space-around;
}

.items-center {
align-items: center;
}

/* position */
.absolute {
position: absolute;
Expand Down
82 changes: 82 additions & 0 deletions src/utils/labels.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
[
"person",
"bicycle",
"car",
"motorcycle",
"airplane",
"bus",
"train",
"truck",
"boat",
"traffic light",
"fire hydrant",
"stop sign",
"parking meter",
"bench",
"bird",
"cat",
"dog",
"horse",
"sheep",
"cow",
"elephant",
"bear",
"zebra",
"giraffe",
"backpack",
"umbrella",
"handbag",
"tie",
"suitcase",
"frisbee",
"skis",
"snowboard",
"sports ball",
"kite",
"baseball bat",
"baseball glove",
"skateboard",
"surfboard",
"tennis racket",
"bottle",
"wine glass",
"cup",
"fork",
"knife",
"spoon",
"bowl",
"banana",
"apple",
"sandwich",
"orange",
"broccoli",
"carrot",
"hot dog",
"pizza",
"donut",
"cake",
"chair",
"couch",
"potted plant",
"bed",
"dining table",
"toilet",
"tv",
"laptop",
"mouse",
"remote",
"keyboard",
"cell phone",
"microwave",
"oven",
"toaster",
"sink",
"refrigerator",
"book",
"clock",
"vase",
"scissors",
"teddy bear",
"hair drier",
"toothbrush"
]
107 changes: 107 additions & 0 deletions src/utils/renderBox.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import labels from './labels.json'

/**
* Render prediction boxes
* @param {HTMLCanvasElement} canvasRef canvas tag reference
* @param {Array} boxes_data boxes array
* @param {Array} scores_data scores array
* @param {Array} classes_data class array
* @param {Array[Number]} ratios boxes ratio [xRatio, yRatio]
*/
export function renderBoxes(
canvasRef: HTMLCanvasElement,
boxes_data: Float32Array | Int32Array | Uint8Array,
scores_data: Float32Array | Int32Array | Uint8Array,
classes_data: Float32Array | Int32Array | Uint8Array,
ratios: any[]
) {
const ctx = canvasRef.getContext('2d')
if (ctx) {
ctx.clearRect(0, 0, ctx.canvas.width, ctx.canvas.height) // clean canvas

const colors = new Colors()

// font configs
const font = `${Math.max(Math.round(Math.max(ctx.canvas.width, ctx.canvas.height) / 40), 14)}px Arial`
ctx.font = font
ctx.textBaseline = 'top'

for (let i = 0; i < scores_data.length; ++i) {
// filter based on class threshold
const klass = labels[classes_data[i]]
const color = colors.get(classes_data[i])
const score = (scores_data[i] * 100).toFixed(1)

let [y1, x1, y2, x2] = boxes_data.slice(i * 4, (i + 1) * 4)
x1 *= ratios[0]
x2 *= ratios[0]
y1 *= ratios[1]
y2 *= ratios[1]
const width = x2 - x1
const height = y2 - y1

// draw box.
ctx.fillStyle = Colors.hexToRgba(color, 0.2)!
ctx.fillRect(x1, y1, width, height)

// draw border box.
ctx.strokeStyle = color
ctx.lineWidth = Math.max(Math.min(ctx.canvas.width, ctx.canvas.height) / 200, 2.5)
ctx.strokeRect(x1, y1, width, height)

// Draw the label background.
ctx.fillStyle = color
const textWidth = ctx.measureText(`${klass} - ${score}%`).width
const textHeight = parseInt(font, 10) // base 10
const yText = y1 - (textHeight + ctx.lineWidth)
ctx.fillRect(
x1 - 1,
yText < 0 ? 0 : yText, // handle overflow label box
textWidth + ctx.lineWidth,
textHeight + ctx.lineWidth
)

// Draw labels
ctx.fillStyle = '#ffffff'
ctx.fillText(`${klass} - ${score}%`, x1 - 1, yText < 0 ? 0 : yText)
}
}
}

class Colors {
palette: string[]
n: number
// ultralytics color palette https://ultralytics.com/
constructor() {
this.palette = [
'#FF3838',
'#FF9D97',
'#FF701F',
'#FFB21D',
'#CFD231',
'#48F90A',
'#92CC17',
'#3DDB86',
'#1A9334',
'#00D4BB',
'#2C99A8',
'#00C2FF',
'#344593',
'#6473FF',
'#0018EC',
'#8438FF',
'#520085',
'#CB38FF',
'#FF95C8',
'#FF37C7'
]
this.n = this.palette.length
}

get = (i: number) => this.palette[Math.floor(i) % this.n]

static hexToRgba = (hex: string, alpha: number) => {
const result = /^#?([a-f\d]{2})([a-f\d]{2})([a-f\d]{2})$/i.exec(hex)
return result ? `rgba(${[parseInt(result[1], 16), parseInt(result[2], 16), parseInt(result[3], 16)].join(', ')}, ${alpha})` : null
}
}
110 changes: 108 additions & 2 deletions src/utils/tf.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
import type { GraphModel, io, Rank, Tensor, Tensor1D, Tensor2D, Tensor3D } from '@tensorflow/tfjs'
import * as tf from '@tensorflow/tfjs'
import '@tensorflow/tfjs-backend-webgl'
import { yolo } from '@/vue-pinia'

import { renderBoxes } from '@/utils/renderBox'
import labels from '@/utils/labels.json'

const numClass = labels.length

export function loadModel() {
tf.ready().then(async () => {
const model = await tf.loadGraphModel(`${process.env.VUE_APP_PUBLIC_PATH}${yolo().name}_web_model/model.json`, {
const model = await tf.loadGraphModel(`${process.env.VUE_APP_PUBLIC_PATH}${yolo().version}_web_model/model.json`, {
onProgress: progress => {
console.log('tf.loadGraphModel', progress)
yolo().loading = progress
Expand All @@ -16,10 +22,110 @@ export function loadModel() {
const dummyInput = tf.ones(model.inputs[0].shape)
const warmupResults = model.execute(dummyInput)

yolo().net = model
yolo().model = model
yolo().inputShape = model.inputs[0].shape

tf.dispose([warmupResults, dummyInput]) // cleanup memory
}
})
}

/**
* Preprocess image / frame before forwarded into the model
* @param {HTMLVideoElement|HTMLImageElement} source
* @param {Number} modelWidth
* @param {Number} modelHeight
* @returns input tensor, xRatio and yRatio
*/
function preprocess(source: HTMLVideoElement | HTMLImageElement, modelWidth: number, modelHeight: number): [Tensor, number, number] {
// ratios for boxes
let xRatio = 0
let yRatio = 0

const input = tf.tidy(() => {
const img = tf.browser.fromPixels(source)

// padding image to square => [n, m] to [n, n], n > m
const [h, w] = img.shape.slice(0, 2) // get source width and height
console.log('image', h, w)

const maxSize = Math.max(w, h) // get max size
const imgPadded = img.pad([
[0, maxSize - h], // padding y [bottom only]
[0, maxSize - w], // padding x [right only]
[0, 0]
]) as Tensor3D

xRatio = maxSize / w // update xRatio
yRatio = maxSize / h // update yRatio

return tf.image
.resizeBilinear(imgPadded, [modelWidth, modelHeight]) // resize frame
.div(255.0) // normalize
.expandDims(0) // add batch
})

return [input, xRatio, yRatio]
}

/**
* Function run inference and do detection from source.
* @param {tf.GraphModel} model loaded YOLOv8 tensorflow.js model
* @param {number[]} inputShape
* @param {HTMLImageElement|HTMLVideoElement} source
* @param {HTMLCanvasElement} canvasRef canvas reference
* @param {VoidFunction} callback function to run after detection process
*/
export async function detect(
model: GraphModel<string | io.IOHandler>,
inputShape: number[],
source: HTMLImageElement | HTMLVideoElement,
canvasRef: HTMLCanvasElement,
callback: () => void
) {
tf.engine().startScope() // start scoping tf engine
const [modelWidth, modelHeight] = inputShape.slice(1, 3) // get model width and height
console.log('shape', modelWidth, modelHeight)

const [input, xRatio, yRatio] = preprocess(source, modelWidth, modelHeight) // preprocess image
console.log('ratio', xRatio, yRatio)

const res = toRaw(model).execute(input) as Tensor<Rank> // inference model
const transRes = res.transpose([0, 2, 1]) // transpose result [b, det, n] => [b, n, det]

const boxes = tf.tidy(() => {
const w = transRes.slice([0, 0, 2], [-1, -1, 1]) // get width
const h = transRes.slice([0, 0, 3], [-1, -1, 1]) // get height
const x1 = tf.sub(transRes.slice([0, 0, 0], [-1, -1, 1]), tf.div(w, 2)) // x1
const y1 = tf.sub(transRes.slice([0, 0, 1], [-1, -1, 1]), tf.div(h, 2)) // y1
return tf
.concat(
[
y1,
x1,
tf.add(y1, h), // y2
tf.add(x1, w) // x2
],
2
)
.squeeze()
}) as Tensor2D // process boxes [y1, x1, y2, x2]

const [scores, classes] = tf.tidy(() => {
const rawScores = transRes.slice([0, 0, 4], [-1, -1, numClass]).squeeze() // #6 only squeeze axis 0 to handle only 1 class models
return [rawScores.max(1), rawScores.argMax(1)]
}) as [Tensor1D, Tensor2D] // get max scores and classes index

const nms = await tf.image.nonMaxSuppressionAsync(boxes, scores, 500, 0.45, 0.2) // NMS to filter boxes

const boxes_data = boxes.gather(nms, 0).dataSync() // indexing boxes by nms index
const scores_data = scores.gather(nms, 0).dataSync() // indexing scores by nms index
const classes_data = classes.gather(nms, 0).dataSync() // indexing classes by nms index

renderBoxes(canvasRef, boxes_data, scores_data, classes_data, [xRatio, yRatio]) // render boxes
tf.dispose([res, transRes, boxes, scores, classes, nms]) // clear memory

callback()

tf.engine().endScope() // end of scoping
}
Loading

0 comments on commit 104988a

Please sign in to comment.