Skip to content

Commit

Permalink
Update executorch.js (#1175)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Jan 18, 2025
1 parent 8d05b70 commit 73ab4d0
Show file tree
Hide file tree
Showing 4 changed files with 276 additions and 36 deletions.
104 changes: 69 additions & 35 deletions source/executorch.js
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,14 @@ executorch.Graph = class {
this.outputs = [];
this.nodes = [];
const values = new Map();
values.map = (arg) => {
if (!values.has(arg)) {
const v = plan.values[arg].val;
if (v instanceof executorch.schema.Tensor || v instanceof executorch.schema.TensorList) {
values.map = (index, output) => {
if (!values.has(index)) {
const v = plan.values[index].val;
const tensor = v instanceof executorch.schema.Tensor || v instanceof executorch.schema.TensorList;
if (output && !tensor) {
const value = [new executorch.Value(index.toString(), null, null)];
values.set(index, { type: null, value });
} else if (tensor) {
const tensors = v instanceof executorch.schema.Tensor ? [v] : Array.from(v.items).map((arg) => plan.values[arg].val);
const list = [];
for (let i = 0; i < tensors.length; i++) {
Expand All @@ -62,26 +66,27 @@ executorch.Graph = class {
if (v.data_buffer_idx > 0) {
initializer = new executorch.Tensor(tensor);
}
const identifier = tensors.length > 1 ? `${arg}.${i}` : arg.toString();
list.push(new executorch.Value(identifier, type, initializer));
const identifier = tensors.length > 1 ? `${index}.${i}` : index.toString();
const value = new executorch.Value(identifier, type, initializer);
list.push(value);
}
values.set(arg, { type: null, value: list });
values.set(index, { type: null, value: list });
} else if (v instanceof executorch.schema.Bool) {
values.set(arg, { type: 'int64', value: v.bool_val });
values.set(index, { type: 'int64', value: v.bool_val });
} else if (v instanceof executorch.schema.Int) {
values.set(arg, { type: 'int64', value: v.int_val });
values.set(index, { type: 'int64', value: v.int_val });
} else if (v instanceof executorch.schema.IntList) {
const list = v.items.map((index) => plan.values[index].val.int_val);
values.set(arg, { type: 'int64[]', value: list });
values.set(index, { type: 'int64[]', value: list });
} else if (v instanceof executorch.schema.Double) {
values.set(arg, { type: 'float64', value: v.double_val });
values.set(index, { type: 'float64', value: v.double_val });
} else if (v instanceof executorch.schema.Null) {
values.set(arg, { type: 'attribute', value: null });
values.set(index, { type: 'attribute', value: null });
} else {
throw new Error('Value type not implemented.');
}
}
return values.get(arg);
return values.get(index);
};
for (const input of plan.inputs) {
const value = values.map(input);
Expand Down Expand Up @@ -128,13 +133,17 @@ executorch.Node = class {
this.name = '';
this.inputs = [];
this.outputs = [];
this.attributes = [];
const instr_args = instruction.instr_args;
if (instr_args instanceof executorch.schema.KernelCall) {
const op = plan.operators[instr_args.op_index];
const name = op.name.split('::').pop();
const identifier = op.overload ? `${op.name}.${op.overload}` : op.name;
const schemas = execution.invoke('torch._C._jit_get_schemas_for_operator', [op.name]);
const schema = schemas.find((schema) => schema.name === op.name && schema.overload_name === op.overload);
if (!schema) {
throw new executorch.Error(`Operator schema for '${identifier}' not found.`);
}
const category = schema && schema.category ? schema.category : '';
const alias = (arg) => arg && arg.alias_info && arg.alias_info.before_set.length === 1 ? arg.alias_info.before_set[0] : null;
const outputs = new Set(schema && Array.isArray(schema.returns) ? schema.returns.map((arg) => alias(arg)).filter((alias) => alias !== null) : []);
Expand All @@ -143,57 +152,82 @@ executorch.Node = class {
let i = 0;
const args = instr_args.args;
for (; i < schema.arguments.length; i++) {
const v = args[i];
const index = args[i];
const arg = schema && i < schema.arguments.length ? schema.arguments[i] : null;
const output = arg ? alias(schema.arguments[i]) : null;
if (output && outputs.has(output)) {
inputs.set(output, v);
inputs.set(output, index);
continue;
}
const name = arg ? arg.name : i.toString();
const value = values.map(v);
const value = values.map(index);
const argument = new executorch.Argument(name, value.value, value.type);
this.inputs.push(argument);
}
for (let j = 0; j < schema.returns.length; j++) {
const ret = schema.returns[j];
const output = alias(ret);
const v = output && inputs.has(output) ? inputs.get(output) : args[i++];
let index = args[i++];
index = output && inputs.has(output) ? inputs.get(output) : index;
const name = ret.name;
const value = values.map(v);
const value = values.map(index, true);
const argument = new executorch.Argument(name || '', value.value, value.type);
this.outputs.push(argument);
}
} else if (instr_args instanceof executorch.schema.DelegateCall) {
const delegate = plan.delegates[instr_args.delegate_index];
const args = instr_args.args;
const name = delegate.id;
this.type = { name };
switch (name) {
case 'XnnpackBackend': {
const input = values.map(args[0]);
const output = values.map(args[1], true);
this.inputs.push(new executorch.Argument('input', input.value, input.type));
this.outputs.push(new executorch.Argument('output', output.value, output.type));
break;
}
case 'CoreMLBackend': {
const input = values.map(args[0]);
const output = values.map(args[1], true);
this.inputs.push(new executorch.Argument('input', input.value, input.type));
this.outputs.push(new executorch.Argument('output', output.value, output.type));
break;
}
default: {
throw new executorch.Error(`ExecuTorch delegate '${name}' not implemented.`);
}
}
for (const spec of delegate.compile_specs) {
const value = ArrayBuffer.isView(spec.value) ? Array.from(spec.value) : spec.value;
const attribute = new executorch.Argument(spec.key, value);
this.attributes.push(attribute);
}
} else {
throw new Error('Instruction argument not implemented.');
throw new Error(`Instruction type '${instr_args.constructor.name}' not implemented.`);
}
}
};

executorch.TensorType = class {

constructor(tensor) {
const ScalarType = executorch.schema.ScalarType;
switch (tensor.scalar_type) {

case ScalarType.BOOL: this.dataType = 'boolean'; break;
case ScalarType.BYTE: this.dataType = 'uint8'; break;
case ScalarType.CHAR: this.dataType = 'int8'; break;
case ScalarType.SHORT: this.dataType = 'int16'; break;
case ScalarType.INT: this.dataType = 'int32'; break;
case ScalarType.LONG: this.dataType = 'int64'; break;
case ScalarType.HALF: this.dataType = 'float16'; break;
case ScalarType.FLOAT: this.dataType = 'float32'; break;
case ScalarType.DOUBLE: this.dataType = 'float64'; break;
case ScalarType.UINT16: this.dataType = 'uint16'; break;
case ScalarType.UINT32: this.dataType = 'uint32'; break;
case ScalarType.UINT64: this.dataType = 'uint64'; break;
default: throw new executorch.Error(`Unknown tensor data type '${tensor.scalar_type}'.`);
executorch.TensorType._types = executorch.TensorType._types || [
'uint8',
'int8', 'int16', 'int32', 'int64',
'float16', 'float32', 'float64',
'complex16', 'complex32', 'complex64',
'boolean',
'qint8', 'quint8', 'qint32',
'bfloat16',
'quint4x2', 'quint2x4', 'bits1x8', 'bits2x4', 'bits4x2', 'bits8', 'bits16',
'float8e5m2', 'float8e4m3fn', 'float8e5m2fnuz', 'float8e4m3fnuz',
'uint16', 'uint32', 'uint64'
];
if (tensor.scalar_type >= executorch.TensorType._types.length) {
throw new executorch.Error(`Unknown tensor data type '${tensor.scalar_type}'.`);
}
this.dataType = executorch.TensorType._types.length[tensor.scalar_type];
this.shape = new executorch.TensorShape(Array.from(tensor.sizes));
}

Expand Down
2 changes: 1 addition & 1 deletion source/python.js
Original file line number Diff line number Diff line change
Expand Up @@ -18340,7 +18340,7 @@ python.Execution = class {
torch.quint2x4 = new torch.dtype(17, 'quint2x4');
torch.bits1x8 = new torch.dtype(18, 'bits1x8');
torch.bits2x4 = new torch.dtype(19, 'bits2x4');
torch.bits2x4 = new torch.dtype(20, 'bits2x4');
torch.bits4x2 = new torch.dtype(20, 'bits4x2');
torch.bits8 = new torch.dtype(21, 'bits8');
torch.bits16 = new torch.dtype(22, 'bits16');
torch.float8_e5m2 = new torch.dtype(23, 'float8_e5m2', 1);
Expand Down
Loading

0 comments on commit 73ab4d0

Please sign in to comment.