This repository has been archived by the owner on Sep 6, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 5
/
signature.js
95 lines (87 loc) · 3.45 KB
/
signature.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
#!/usr/bin/env -S node --no-deprecation --trace-warnings
const fs = require('fs');
const path = require('path');
// @ts-ignore
const log = require('@vladmandic/pilogger');
// eslint-disable-next-line node/no-unpublished-require, import/no-extraneous-dependencies
const tf = require('@tensorflow/tfjs-node');
/**
 * Convert the `tensorShape` entry of a signature tensor into a list of sizes.
 *
 * @param {{ dim?: Array<{ size: string | number }> } | undefined} tensorShape - shape entry taken from a signature tensor
 * @returns {number[] | undefined} each `dim[].size` parsed as a base-10 integer, or `undefined` when no `dim` list is present
 */
function parseSignatureShape(tensorShape) {
  return tensorShape?.dim?.map((dim) => Number.parseInt(dim.size, 10));
}

/**
 * Load a graph model and log its input and output tensor definitions.
 *
 * Inputs are taken from `model.modelSignature.inputs` when present and from
 * `model.executor.graph.inputs` otherwise. Outputs are taken from the signature
 * only when its first entry carries a `dtype`; otherwise the executor graph
 * outputs are used. A warning is logged for a side that neither source provides.
 *
 * @param {string} modelPath - path of the model file, loaded through a `file://` URL
 * @returns {Promise<void>} resolves after the `inputs:` and `outputs:` lists have been logged
 */
async function analyzeGraph(modelPath) {
  const model = await tf.loadGraphModel(`file://${modelPath}`);
  log.info('graph model:', path.resolve(modelPath));
  log.info('size:', tf.engine().memory());

  // Either source may be missing; optional chaining keeps the fallback branches reachable
  const signature = model.modelSignature;
  const graph = model.executor?.graph;

  const inputs = [];
  if (signature?.inputs) {
    log.info('model inputs based on signature');
    for (const [name, info] of Object.entries(signature.inputs)) {
      inputs.push({ name, dtype: info.dtype, shape: parseSignatureShape(info.tensorShape) });
    }
  } else if (graph?.inputs) {
    log.info('model inputs based on executor');
    for (const node of graph.inputs) {
      inputs.push({
        name: node.name,
        dtype: node.attrParams?.dtype?.value,
        shape: node.attrParams?.shape?.value,
      });
    }
  } else {
    log.warn('model inputs: cannot determine');
  }

  const outputs = [];
  const signatureOutputs = signature?.outputs ?? {};
  if (Object.values(signatureOutputs)[0]?.dtype) {
    log.info('model outputs based on signature');
    for (const [name, info] of Object.entries(signatureOutputs)) {
      // `id` is the position of the output in iteration order
      outputs.push({ id: outputs.length, name, dtype: info.dtype, shape: parseSignatureShape(info.tensorShape) });
    }
  } else if (graph?.outputs) {
    log.info('model outputs based on executor');
    for (const node of graph.outputs) {
      outputs.push({
        id: outputs.length,
        name: node.name,
        // fall back to the raw `T` attribute when the node has no parsed dtype
        dtype: node.attrParams?.dtype?.value || node.rawAttrs?.T?.type,
        shape: node.attrParams?.shape?.value,
      });
    }
  } else {
    log.warn('model outputs: cannot determine');
  }

  log.data('inputs:', inputs);
  log.data('outputs:', outputs);
}
/**
 * Unwrap the dimension list reported for one saved-model tensor.
 *
 * NOTE(review): each dimension is read from `array[0]` of a wrapper object, as
 * the original code did; plain numbers are passed through unchanged in case the
 * library reports sizes directly — confirm against the tfjs-node version in use.
 *
 * @param {Array<number | { array?: unknown[] }> | undefined} dims - the `shape` field of a tensor info entry
 * @returns {unknown[] | undefined} one size per dimension, or `undefined` when no shape is reported
 */
function unwrapSavedShape(dims) {
  return dims?.map((dim) => (typeof dim === 'number' ? dim : dim?.array?.[0]));
}

/**
 * Log tags, signature names, inputs and outputs of a saved model.
 *
 * Only the first meta graph and its first signature definition are inspected.
 * All inputs of that signature are listed (as an array, matching how
 * `outputs:` is logged); a warning is logged and nothing else is reported when
 * the model has no meta graph or no signature definition.
 *
 * @param {string} modelPath - path handed to `tf.node.getMetaGraphsFromSavedModel`
 * @returns {Promise<void>} resolves after the model details have been logged
 */
async function analyzeSaved(modelPath) {
  const meta = await tf.node.getMetaGraphsFromSavedModel(modelPath);
  log.info('saved model:', path.resolve(modelPath));

  const [metaGraph] = meta ?? [];
  if (!metaGraph) {
    log.warn('saved model: no meta graphs found');
    return;
  }

  const signatureDefs = metaGraph.signatureDefs ?? {};
  log.data('tags:', metaGraph.tags);
  log.data('signature:', Object.keys(signatureDefs));

  const [sign] = Object.values(signatureDefs);
  if (!sign) {
    log.warn('saved model: no signature definitions found');
    return;
  }

  // inputs keep reporting the tensor's own `name`, outputs the signature key
  const inputs = Object.values(sign.inputs ?? {}).map((info) => ({
    name: info.name,
    dtype: info.dtype,
    shape: unwrapSavedShape(info.shape),
  }));
  log.data('inputs:', inputs);

  const outputs = Object.entries(sign.outputs ?? {}).map(([key, info], id) => ({
    id,
    name: key,
    dtype: info.dtype,
    shape: unwrapSavedShape(info.shape),
  }));
  log.data('outputs:', outputs);
}
/**
 * Command-line entry point: inspect the model found at the single path argument.
 *
 * A file ending in `.json` is analyzed as a graph model. A directory is
 * analyzed as a saved model when it contains `saved_model.pb` and as a graph
 * model when it contains `model.json`; when both exist the analyses run one
 * after the other so their log output does not interleave.
 *
 * A missing or non-existent path is logged as an error and reported through a
 * non-zero process exit code.
 *
 * @returns {Promise<void>} resolves once every started analysis has finished
 */
async function main() {
  log.header();

  if (process.argv.length !== 3) {
    log.error('path required');
    process.exitCode = 1; // usage errors must not look like success to the caller
    return;
  }

  const param = process.argv[2];
  if (!fs.existsSync(param)) {
    log.error(`path does not exist: ${param}`);
    process.exitCode = 1;
    return;
  }

  const stat = fs.statSync(param);
  log.data('created on:', stat.birthtime);

  if (stat.isFile()) {
    if (param.endsWith('.json')) await analyzeGraph(param);
    return;
  }

  if (stat.isDirectory()) {
    const graphModelFile = path.join(param, 'model.json');
    // awaited so that failures propagate to the caller of main()
    if (fs.existsSync(path.join(param, 'saved_model.pb'))) await analyzeSaved(param);
    if (fs.existsSync(graphModelFile)) await analyzeGraph(graphModelFile);
  }
}
// Start the CLI; a rejection escaping main() is logged and turned into a failing exit code
main().catch((err) => {
  log.error('signature analysis failed:', err);
  process.exitCode = 1;
});