graph.ts
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
import {Attribute} from './attribute';
import {onnxruntime} from './ort-schema/flatbuffers/ort-generated';
import {onnx} from './ort-schema/protobuf/onnx';
import {Tensor} from './tensor';
import {LongUtil, MAX_CLIP, MIN_CLIP, ProtoUtil} from './util';
import ortFbs = onnxruntime.experimental.fbs;
export declare namespace Graph {
export interface Shape {
readonly dims: readonly number[];
}
export interface ValueType {
readonly tensorType: Tensor.DataType;
readonly shape: Shape;
}
export interface Value {
// the tensor data. empty for non-initialized inputs
readonly tensor?: Tensor;
// index to the Node where the value comes from. -1 for initializer.
readonly from: number;
// indices to the Nodes where the values go to.
readonly to: readonly number[];
// value type specification. empty for non-input values.
readonly type?: ValueType;
}
export interface Node {
// name of the node
readonly name: string;
// the operator type
readonly opType: string;
// indices to the Values where the inputs come from.
readonly inputs: readonly number[];
// indices to the Values where the outputs go to.
readonly outputs: readonly number[];
// the attributes used by the operator
readonly attributes: Attribute;
}
/**
* a Transformer is an instance that exposes all transformation operations that can be applied to a graph
*/
export interface Transformer {
removeAllIdentityNodes(): void;
removeAllDropoutNodes(): void;
fuseConvActivationNodes(): void;
// TODO: add generic functions to manipulate the graph
}
// an Initializer can use the transformer to transform the graph
export interface Initializer {
transformGraph(transformer: Transformer): void;
}
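// Illustrative sketch: a minimal custom Graph.Initializer that only runs the built-in
// Conv + activation fusion pass (the constant name below is hypothetical):
//
//   const fuseOnlyInitializer: Graph.Initializer = {
//     transformGraph(transformer: Graph.Transformer): void {
//       transformer.fuseConvActivationNodes();
//     }
//   };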
}
// eslint-disable-next-line @typescript-eslint/no-redeclare
export interface Graph {
getInputIndices(): readonly number[];
getInputNames(): readonly string[];
getOutputIndices(): readonly number[];
getOutputNames(): readonly string[];
getValues(): readonly Graph.Value[];
getNodes(): readonly Graph.Node[];
}
// eslint-disable-next-line @typescript-eslint/naming-convention, @typescript-eslint/no-redeclare
export const Graph = {
/**
* construct a graph from a graph protobuf type
*/
from: (graphProto: onnx.IGraphProto|ortFbs.Graph, initializer?: Graph.Initializer) =>
new GraphImpl(graphProto, initializer),
};
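// Illustrative usage sketch (assumes a serialized ONNX model in `modelBuffer` and the
// protobufjs-generated decode() API of the onnx module imported above):
//
//   const model = onnx.ModelProto.decode(modelBuffer);
//   const graph = Graph.from(model.graph!);
//   const inputNames = graph.getInputNames();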
class Value implements Graph.Value {
constructor(valueInfo?: onnx.IValueInfoProto) {
this._from = undefined;
this._to = [];
this.tensor = undefined;
this.type = undefined;
if (valueInfo) {
this.type = ProtoUtil.tensorValueTypeFromProto(valueInfo.type!.tensorType!);
}
}
_from?: number; // -1 represents an initializer; -2 marks a value whose producer node was removed (see finalizeGraph)
get from() {
return this._from!;
}
_to: number[];
get to() {
return this._to;
}
type?: Graph.ValueType;
tensor?: Tensor;
}
class Node implements Graph.Node {
constructor(_nodeProto: onnx.INodeProto|ortFbs.Node, name?: string) {
if (_nodeProto instanceof onnx.NodeProto) {
this.name = _nodeProto.name;
this.opType = _nodeProto.opType;
this.attributes = new Attribute(_nodeProto.attribute);
} else if (_nodeProto instanceof ortFbs.Node) {
this.name = name ?? _nodeProto.name()!;
this.opType = _nodeProto.opType()!;
this.attributes = new Attribute(ProtoUtil.tensorAttributesFromORTFormat(_nodeProto));
}
this.inputs = [];
this.outputs = [];
this.executeNode = true;
}
name: string;
opType: string;
inputs: number[];
outputs: number[];
attributes: Attribute;
executeNode: boolean;
}
class GraphImpl implements Graph, Graph.Transformer {
private _allData: Value[];
private _allInputIndices: number[];
private _allInputNames: string[];
private _allOutputIndices: number[];
private _allOutputNames: string[];
private _nodes: Node[];
constructor(graph: onnx.IGraphProto|ortFbs.Graph, graphInitializer?: Graph.Initializer) {
if (!graph) {
throw new TypeError('graph is empty');
}
// build the graph - will throw exceptions if something fatal is detected
this.buildGraph(graph);
// execute any transformation logic for the graph (if applicable)
this.transformGraph(graphInitializer);
// check for cycles and other inconsistencies - will throw exceptions if something fatal is detected
this.checkIsAcyclic();
}
getInputIndices(): readonly number[] {
return this._allInputIndices;
}
getInputNames(): readonly string[] {
return this._allInputNames;
}
getOutputIndices(): readonly number[] {
return this._allOutputIndices;
}
getOutputNames(): readonly string[] {
return this._allOutputNames;
}
getValues(): readonly Graph.Value[] {
return this._allData;
}
getNodes(): readonly Graph.Node[] {
return this._nodes;
}
private buildGraph(graph: onnx.IGraphProto|ortFbs.Graph) {
// build the graph - will throw exceptions if something fatal is detected
if (graph instanceof onnx.GraphProto) {
this.buildGraphFromOnnxFormat(graph);
} else if (graph instanceof ortFbs.Graph) {
this.buildGraphFromOrtFormat(graph);
} else {
throw new TypeError('Graph type is not supported.');
}
}
private buildGraphFromOnnxFormat(graph: onnx.IGraphProto) {
const dataIndices = new Map<string, number>();
this._allData = [];
this._allInputIndices = [];
this._allInputNames = [];
this._allOutputIndices = [];
this._allOutputNames = [];
this._nodes = [];
const nodesIndices = new Map<string, number>();
// scan all inputs
if (!graph.input) {
throw new Error('missing information in graph: input');
}
const inputValueNames = [];
for (const i of graph.input) {
if (dataIndices.has(i.name!)) {
throw new Error(`duplicated input name: ${i.name}`);
}
const currentIndex = this._allData.push(new Value(i)) - 1;
dataIndices.set(i.name!, currentIndex);
inputValueNames.push(i.name!);
}
// scan all initializers
if (!graph.initializer) {
throw new Error('missing information in graph: initializer');
}
for (const i of graph.initializer) {
let index = dataIndices.get(i.name!);
if (index === undefined) {
const value = new Value();
value.type = {
shape: {dims: ProtoUtil.tensorDimsFromProto(i.dims!)},
tensorType: ProtoUtil.tensorDataTypeFromProto(i.dataType!)
};
index = this._allData.push(value) - 1;
dataIndices.set(i.name!, index);
}
this._allData[index]._from = -1;
this._allData[index].tensor = Tensor.fromProto(i);
}
// filter out input indices
for (let i = 0; i < this._allData.length; i++) {
if (!this._allData[i].tensor) {
this._allInputIndices.push(i);
this._allInputNames.push(inputValueNames[i]);
}
}
// scan all outputs
if (!graph.output) {
throw new Error('missing information in graph: output');
}
for (const i of graph.output) {
if (dataIndices.has(i.name!)) {
throw new Error(`duplicated output name: ${i.name}`);
}
const currentIndex = this._allData.push(new Value(i)) - 1;
dataIndices.set(i.name!, currentIndex);
this._allOutputIndices.push(currentIndex);
this._allOutputNames.push(i.name!);
}
// scan all nodes
if (!graph.node) {
throw new Error('missing information in graph: node');
}
for (const nodeProto of graph.node) {
if (!nodeProto.name) {
// assign a name to the node if it doesn't have one
for (let pick = 0;; pick++) {
const name = `unnamed_${nodeProto.opType}_${pick}`;
if (!nodesIndices.has(name)) {
nodeProto.name = name;
break;
}
}
}
if (nodesIndices.has(nodeProto.name)) {
throw new Error(`duplicated node name: ${nodeProto.name}`);
}
const currentIndex = this._nodes.push(new Node(nodeProto)) - 1;
nodesIndices.set(nodeProto.name, currentIndex);
}
// scan node's outputs
for (let i = 0; i < this._nodes.length; i++) {
const node = this._nodes[i];
const nodeProto = graph.node[i];
if (!nodeProto.output) {
throw new Error(`missing output for node: ${nodeProto.name}`);
}
for (const output of nodeProto.output) {
let dataIndex = dataIndices.get(output);
if (typeof dataIndex === 'undefined') {
dataIndex = this._allData.push(new Value()) - 1;
dataIndices.set(output, dataIndex);
}
node.outputs.push(dataIndex);
if (this._allData[dataIndex]._from !== undefined) {
throw new Error(`multiple nodes output to one data value: ${dataIndex}`);
}
this._allData[dataIndex]._from = i;
// for the 'Constant' operator, just create a new edge in the graph corresponding to the 'output' of the
// operator and ignore the node from the graph
if (nodeProto.opType === 'Constant') {
if (!nodeProto.attribute || nodeProto.attribute.length !== 1 || !nodeProto.attribute[0].t) {
throw new Error('missing attributes or missing tensor value in attributes for this Constant operator');
}
if (!nodeProto.output || nodeProto.output.length !== 1) {
throw new Error('missing output or incorrect number of outputs for this Constant operator');
}
node.outputs.pop();
node.executeNode = false;
this._allData[dataIndex]._from = -1;
this._allData[dataIndex].tensor = Tensor.fromProto(nodeProto.attribute[0].t);
}
}
}
// scan node's inputs
for (let i = 0; i < this._nodes.length; i++) {
const node = this._nodes[i];
const nodeProto = graph.node[i];
if (!nodeProto.input) {
throw new Error(`missing input for node: ${nodeProto.name}`);
}
for (const input of nodeProto.input) {
const dataIndex = dataIndices.get(input);
if (typeof dataIndex === 'undefined') {
// handle the special case when opset > 9 and the optional roi / scales inputs of Resize are not given (empty name)
if (input === '' && (nodeProto.input.length === 3 || nodeProto.input.length === 4) &&
nodeProto.opType === 'Resize') {
continue;
}
throw new Error(`unrecognized input '${input}' for node: ${nodeProto.name}`);
}
node.inputs.push(dataIndex);
this._allData[dataIndex]._to.push(i);
}
}
return true;
}
private buildGraphFromOrtFormat(graph: ortFbs.Graph) {
const dataIndices = new Map<string, number>();
this._allData = [];
this._allInputIndices = [];
this._allInputNames = [];
this._allOutputIndices = [];
this._allOutputNames = [];
this._nodes = [];
const nodesIndices = new Map<string, number>();
// scan all inputs
const inputValueNames = [];
for (let i = 0; i < graph.inputsLength(); i++) {
const inputName = graph.inputs(i);
if (dataIndices.has(inputName)) {
throw new Error(`duplicated input name: ${inputName}`);
}
// Find the input typeInfo from nodeargs
for (let j = 0; j < graph.nodeArgsLength(); j++) {
if (graph.nodeArgs(j)?.name() === inputName) {
const value = new Value();
const valueType = graph.nodeArgs(j)?.type()?.valueType();
if (valueType !== ortFbs.TypeInfoValue.tensor_type) {
throw new Error('Unexpected value type for the nodeArg.');
}
const valueInfo = graph.nodeArgs(j)!.type()!.value(new ortFbs.TensorTypeAndShape())!;
const type = ProtoUtil.tensorDataTypeFromProto(valueInfo.elemType());
const shape = valueInfo.shape()!;
const dims = [];
for (let k = 0; k < shape.dimLength()!; k++) {
dims.push(LongUtil.longToNumber(shape.dim(k)!.value()!.dimValue()!));
}
value.type = {shape: {dims}, tensorType: type};
const currentIndex = this._allData.push(value) - 1;
dataIndices.set(inputName, currentIndex);
inputValueNames.push(inputName);
}
}
}
// check initializers
for (let i = 0; i < graph.initializersLength(); i++) {
const initializer = graph.initializers(i)!;
let index = dataIndices.get(initializer.name()!);
if (index === undefined) {
const value = new Value();
const dims = ProtoUtil.tensorDimsFromORTFormat(initializer);
const type = ProtoUtil.tensorDataTypeFromProto(initializer.dataType());
value.type = {shape: {dims}, tensorType: type};
index = this._allData.push(value) - 1;
dataIndices.set(initializer.name()!, index);
}
this._allData[index]._from = -1;
this._allData[index].tensor = Tensor.fromOrtTensor(initializer);
}
// filter out input indices
for (let i = 0; i < this._allData.length; i++) {
if (!this._allData[i].tensor) {
this._allInputIndices.push(i);
this._allInputNames.push(inputValueNames[i]);
}
}
// scan all outputs
for (let i = 0; i < graph.outputsLength(); i++) {
const outputName = graph.outputs(i);
if (dataIndices.has(outputName)) {
throw new Error(`duplicated output name: ${outputName}`);
}
const currentIndex = this._allData.push(new Value()) - 1;
dataIndices.set(outputName, currentIndex);
this._allOutputIndices.push(currentIndex);
this._allOutputNames.push(outputName);
}
// scan all nodes
if (!graph.nodes) {
throw new Error('missing information in graph: node');
}
for (let i = 0; i < graph.nodesLength(); i++) {
const nodeProto = graph.nodes(i);
let name = nodeProto!.name();
if (!name) {
// assign a name to the node if it doesn't have one
for (let pick = 0;; pick++) {
name = `unnamed_${nodeProto!.opType()}_${pick}`;
if (!nodesIndices.has(name)) {
// a unique name is found. break.
break;
}
}
}
if (nodesIndices.has(name)) {
throw new Error(`duplicated node name: ${name}`);
}
const currentIndex = this._nodes.push(new Node(nodeProto!, name)) - 1;
nodesIndices.set(name, currentIndex);
}
// scan node's outputs
for (let i = 0; i < this._nodes.length; i++) {
const node = this._nodes[i];
const nodeProto = graph.nodes(i);
if (nodeProto == null) {
throw new Error(`No node exists at index ${i}`);
}
if (nodeProto.outputsLength() === 0) {
throw new Error(`missing output for node: ${nodeProto.name()}`);
}
for (let j = 0; j < nodeProto.outputsLength(); j++) {
const output = nodeProto.outputs(j)!;
let dataIndex = dataIndices.get(output);
if (typeof dataIndex === 'undefined') {
dataIndex = this._allData.push(new Value()) - 1;
dataIndices.set(output, dataIndex);
}
node.outputs.push(dataIndex);
if (this._allData[dataIndex]._from !== undefined) {
throw new Error(`multiple nodes output to one data value: ${dataIndex}`);
}
this._allData[dataIndex]._from = i;
// for the 'Constant' operator, just create a new edge in the graph corresponding to the 'output' of the
// operator and ignore the node from the graph
if (nodeProto.opType() === 'Constant') {
if (nodeProto.attributesLength() !== 1 || !nodeProto.attributes(0)!.t()) {
throw new Error('missing attributes or missing tensor value in attributes for this Constant operator');
}
if (nodeProto.outputsLength() !== 1) {
throw new Error('missing output or incorrect number of outputs for this Constant operator');
}
node.outputs.pop();
node.executeNode = false;
this._allData[dataIndex]._from = -1;
this._allData[dataIndex].tensor = Tensor.fromOrtTensor(nodeProto.attributes(0)!.t()!);
}
}
}
// scan node's inputs
for (let i = 0; i < this._nodes.length; i++) {
const node = this._nodes[i];
const nodeProto = graph.nodes(i)!;
if (nodeProto.inputsLength() === 0) {
throw new Error(`missing input for node: ${nodeProto.name()}`);
}
for (let j = 0; j < nodeProto.inputsLength(); j++) {
const input = nodeProto.inputs(j)!;
const dataIndex = dataIndices.get(input);
if (typeof dataIndex === 'undefined') {
throw new Error(`unrecognized input '${input}' for node: ${nodeProto.name()}`);
}
node.inputs.push(dataIndex);
this._allData[dataIndex]._to.push(i);
}
}
}
private checkIsAcyclic() {
// go through the graph and check for cycles or other fatal inconsistencies
const starters: Set<number> = new Set<number>();
this._allInputIndices.forEach(i => {
const data = this._allData[i];
data._to.forEach(j => {
starters.add(j);
});
});
// Iterative DFS to check for cycles
const nodesStack = Array.from(starters);
const nodesState = new Array<string>(this._nodes.length).fill('white');
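// Node colors: 'white' = not yet visited, 'gray' = currently on the stack (in progress),
// 'black' = fully processed. Reaching a 'gray' node through an outgoing edge means a
// back edge was found, i.e. the graph contains a cycle.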
while (nodesStack.length > 0) {
const nodeIndex = nodesStack.pop()!;
// this node has now been processed completely. Mark this node 'black' to denote this.
if (nodesState[nodeIndex] === 'gray') {
nodesState[nodeIndex] = 'black';
} else {
// this node is under processing stage. mark this node 'gray' to denote this.
nodesStack.push(nodeIndex);
nodesState[nodeIndex] = 'gray';
this._nodes[nodeIndex].outputs.forEach((outgoingEdgeIndex) => {
const data = this._allData[outgoingEdgeIndex];
if (typeof data.tensor !== 'undefined') {
throw new Error('node outputs should not be initialized');
}
if (data._from !== nodeIndex) {
throw new Error('from property of the Value object doesn\'t match index of Node being processed');
}
data._to.forEach((downstreamNodeIndex) => {
// back edge found - cyclic
if (nodesState[downstreamNodeIndex] === 'gray') {
throw new Error('model graph is cyclic');
}
// tree edge found - continue processing by adding it to stack
else if (nodesState[downstreamNodeIndex] === 'white') {
nodesStack.push(downstreamNodeIndex);
}
});
});
}
}
}
private transformGraph(graphInitializer?: Graph.Initializer): void {
// apply common transform
this.removeAllIdentityNodes();
this.removeAllDropoutNodes();
this.fuseConvActivationNodes();
// apply initializer specific transform
if (graphInitializer) {
graphInitializer.transformGraph(this);
}
// finalize graph
this.finalizeGraph();
}
/**
* finalize the graph.
*
* this function should be called after all the transformation completed.
* this function removes all unnecessary nodes and values from the graph
*/
finalizeGraph() {
let offset = 0;
// delete all nodes that are not being executed
// The graph is represented using these two arrays:
// this._nodes - array holding the kernels to execute; each entry is a kernel pointing into this._allData
// this._allData - each entry holds two fields, 'to' [] and 'from', which map the inputs and outputs per node
// newIndices - remaps the node indices after reading the 'executeNode' flag
const newIndices = new Array<number>(this._nodes.length).fill(0);
let nodePosition = 0;
for (let i = 0; i < this._nodes.length; i++) {
// assign new indices to the nodes based on the execution flag
newIndices[i] = nodePosition;
if (this._nodes[i].executeNode) {
if (nodePosition !== i) {
this._nodes[nodePosition] = this._nodes[i];
}
nodePosition++;
} else {
// mark all output values of this skipped node as coming from a removed node (-2)
this._nodes[i].outputs.forEach(ind => {
this._allData[ind]._from = -2;
});
}
}
// remove the unused nodes
this._nodes.splice(nodePosition, this._nodes.length - nodePosition);
// Updating this._allData according to the new this._nodes
for (let i = 0; i < this._allData.length; i++) {
const currentData = this._allData[i];
if (currentData._from !== undefined && currentData._from !== -1 && currentData._from !== -2) {
currentData._from = newIndices[currentData._from];
}
for (let j = 0; j < currentData._to.length; j++) {
if (currentData._to[j] >= 0) {
currentData._to[j] = newIndices[currentData._to[j]];
} else {
throw new Error('Trying to update a removed node');
}
}
}
offset = 0;
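// `offset` counts how many values have been removed so far; for a value that remains at
// position i, its index before this pass was (i + offset), which is what the node and
// graph input/output index lists still refer to until they are patched below.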
// delete all values that are not being referenced
for (let i = 0; i < this._allData.length; i++) {
// if current value is neither linked to next node, nor an output value, remove it.
if (this._allData[i].from === -2 && this._allOutputIndices.indexOf(i + offset) === -1) {
offset++;
this._allData.splice(i, 1);
i--;
continue;
}
if (offset > 0) {
let ind = -1;
// if current value is neither an input value nor an initializer, find the node it's
// coming from and update the corresponding node output
if (this._allData[i].from !== undefined && this._allData[i].from !== -1) {
ind = this._nodes[this._allData[i].from].outputs.indexOf(i + offset);
if (ind !== -1) {
this._nodes[this._allData[i].from].outputs[ind] = i;
}
} else {
// if current value is an input value, update its reference in inputIndices
ind = this._allInputIndices.indexOf(i + offset);
if (ind !== -1) {
this._allInputIndices[ind] = i;
}
}
// find the node that the current value is linking to and update its input reference
this._allData[i].to.forEach(node => {
ind = this._nodes[node].inputs.indexOf(i + offset);
if (ind !== -1) {
this._nodes[node].inputs[ind] = i;
}
});
if (this._allData[i].to.length === 0) {
// if current value is a graph output, update its reference in outputIndices
ind = this._allOutputIndices.indexOf(i + offset);
if (ind !== -1) {
this._allOutputIndices[ind] = i;
}
}
}
}
}
/**
* Delete the specified node. Assumes the node has one incoming input and that only its first output is connected to other nodes.
* An input validation must be done before calling this function.
* @param nodeIndex The index of node to be deleted
*/
private deleteNode(nodeIndex: number) {
const node = this._nodes[nodeIndex];
if (node.outputs.length > 1) {
for (let i = 1; i < node.outputs.length; i++) {
if (this._allData[node.outputs[i]].to.length > 0) {
throw new Error('Node deletion with more than one output connected to other nodes is not supported. ');
}
}
}
// this node will not be executed
node.executeNode = false;
const inputValueIndex = node.inputs[0];
const outputValueIndex = node.outputs[0];
const nodesConsumingOutput = this._allData[outputValueIndex].to;
// remove this node from the to property of the input Value
for (let i = 0; i < node.inputs.length; i++) {
const delIndex = this._allData[node.inputs[i]].to.indexOf(nodeIndex);
// should not happen
if (delIndex === -1) {
throw new Error('The Value object doesn\'t have the current Node in its \'to\' property');
}
this._allData[node.inputs[i]].to.splice(delIndex, 1);
}
// clear node indices consuming this output Value
this._allData[outputValueIndex]._to = [];
// if the output of this node is a graph output, adjust the index appropriately
const index = this._allOutputIndices.indexOf(outputValueIndex);
if (index !== -1) {
this._allOutputIndices[index] = inputValueIndex;
}
// override the inputs for nodes consuming this node's output with the input to this node
if (nodesConsumingOutput && nodesConsumingOutput.length > 0) {
for (const nodeIndex of nodesConsumingOutput) {
const replaceIndex = this._nodes[nodeIndex].inputs.indexOf(outputValueIndex);
// should not happen
if (replaceIndex === -1) {
throw new Error('The Node object doesn\'t have the output Value in its \'inputs\' property');
}
this._nodes[nodeIndex].inputs[replaceIndex] = inputValueIndex;
this._allData[inputValueIndex].to.push(nodeIndex);
}
}
}
removeAllDropoutNodes() {
let nodeIndex = 0;
for (const node of this._nodes) {
// weed out 'Dropout' nodes so that no time is wasted in execution
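// (at inference time Dropout is a pass-through, so bypassing the node does not change the model's output)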
if (node.opType === 'Dropout') {
// the node should have exactly 1 input and 1 or 2 outputs
if (node.inputs.length !== 1) {
throw new Error('Dropout nodes should only contain one input. ');
}
if (node.outputs.length !== 1 && node.outputs.length !== 2) {
throw new Error('Dropout nodes should contain either 1 or 2 output(s)');
}
// the second output should not be referenced by any other node
if (node.outputs.length === 2 && this._allData[node.outputs[1]]._to.length !== 0) {
throw new Error('Dropout node\'s second output should not be referenced by other nodes');
}
this.deleteNode(nodeIndex);
}
nodeIndex++;
}
}
removeAllIdentityNodes() {
let nodeIndex = 0;
for (const node of this._nodes) {
// weed out 'Identity' nodes so that no time is wasted in execution
if (node.opType === 'Identity') {
this.deleteNode(nodeIndex);
}
nodeIndex++;
}
}
isActivation(n: Node): boolean {
switch (n.opType) {
// TODO: add other activation methods
case 'Relu':
case 'Sigmoid':
case 'Clip':
return true;
default:
return false;
}
}
fuseConvActivationNodes() {
for (const node of this._nodes) {
if (node.opType === 'Conv') {
const next = this._allData[node.outputs[0]]._to;
if (next.length === 1 && this.isActivation(this._nodes[next[0]])) {
const child = this._nodes[next[0]];
if (child.opType === 'Clip') {
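// Clip's min/max may come either from node attributes (opset < 11) or from the optional
// second and third inputs (opset >= 11); handle both forms when building activation_params.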
if (child.inputs.length === 1) {
try {
node.attributes.set(
'activation_params', 'floats',
[child.attributes.getFloat('min'), child.attributes.getFloat('max')]);
} catch (e) {
node.attributes.set('activation_params', 'floats', [MIN_CLIP, MAX_CLIP]);
}
} else if (
child.inputs.length >= 3 && this._allData[child.inputs[1]].tensor !== undefined &&
this._allData[child.inputs[2]].tensor !== undefined) {
node.attributes.set('activation_params', 'floats', [
this._allData[child.inputs[1]].tensor!.floatData[0], this._allData[child.inputs[2]].tensor!.floatData[0]
]);
} else {
// Skip fusion with the Clip node since clip min and clip max do not come from initializers
continue;
}
}
node.attributes.set('activation', 'string', (child.opType));
this.deleteNode(next[0]);
}
}
}
}
}