Reland "webnn: Remove sync methods" #44626

Merged
merged 1 commit into from
Feb 20, 2024
37 changes: 13 additions & 24 deletions webnn/idlharness.https.any.js
@@ -28,29 +28,18 @@ idl_test(
       MLGraph: ['graph']
     });
 
-    for (const executionType of ExecutionArray) {
-      const isSync = executionType === 'sync';
-      if (self.GLOBAL.isWindow() && isSync) {
-        continue;
-      }
-
-      if (isSync) {
-        self.context = navigator.ml.createContextSync();
-      } else {
-        self.context = await navigator.ml.createContext();
-      }
-
-      self.builder = new MLGraphBuilder(self.context);
-      self.input = builder.input('input', {dataType: 'float32', dimensions: [1, 1, 5, 5]});
-      self.filter = builder.constant({dataType: 'float32', dimensions: [1, 1, 3, 3]}, new Float32Array(9).fill(1));
-      self.relu = builder.relu();
-      self.output = builder.conv2d(input, filter, {activation: relu, inputLayout: "nchw"});
-
-      if (isSync) {
-        self.graph = builder.buildSync({output});
-      } else {
-        self.graph = await builder.build({output});
-      }
-    }
+    self.context = await navigator.ml.createContext();
+
+    self.builder = new MLGraphBuilder(self.context);
+    self.input =
+        builder.input('input', {dataType: 'float32', dimensions: [1, 1, 5, 5]});
+    self.filter = builder.constant(
+        {dataType: 'float32', dimensions: [1, 1, 3, 3]},
+        new Float32Array(9).fill(1));
+    self.relu = builder.relu();
+    self.output =
+        builder.conv2d(input, filter, {activation: relu, inputLayout: "nchw"});
+
+    self.graph = await builder.build({output});
   }
 );
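For reference, the async-only flow that the updated harness code above exercises looks roughly like the following. This is a minimal sketch assuming a WebNN-enabled browser and an enclosing async function; the operand names, shapes, and values are taken from the diff and are illustrative only.

// Sketch: async-only WebNN graph construction (no *Sync variants remain).
const context = await navigator.ml.createContext();
const builder = new MLGraphBuilder(context);
// 1x1x5x5 float32 input and a 1x1x3x3 constant filter, as in the harness.
const input = builder.input('input', {dataType: 'float32', dimensions: [1, 1, 5, 5]});
const filter = builder.constant(
    {dataType: 'float32', dimensions: [1, 1, 3, 3]}, new Float32Array(9).fill(1));
const output = builder.conv2d(input, filter, {activation: builder.relu(), inputLayout: 'nchw'});
// build() replaces the removed buildSync() and resolves with an MLGraph.
const graph = await builder.build({output});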
73 changes: 15 additions & 58 deletions webnn/resources/utils.js
@@ -1,7 +1,5 @@
 'use strict';
 
-const ExecutionArray = ['sync', 'async'];
-
 // https://webmachinelearning.github.io/webnn/#enumdef-mloperanddatatype
 const TypedArrayDict = {
   // workaround use Uint16 for Float16
@@ -793,25 +791,7 @@ const buildGraph = (operationName, builder, resources, buildFunc) => {
 };
 
 /**
- * Build a graph, synchronously compile graph and execute, then check computed results.
- * @param {String} operationName - An operation name
- * @param {MLContext} context - A ML context
- * @param {MLGraphBuilder} builder - A ML graph builder
- * @param {Object} resources - Resources used for building a graph
- * @param {Function} buildFunc - A build function for an operation
- */
-const runSync = (operationName, context, builder, resources, buildFunc) => {
-  // build a graph
-  const [namedOutputOperands, inputs, outputs] = buildGraph(operationName, builder, resources, buildFunc);
-  // synchronously compile the graph up to the output operand
-  const graph = builder.buildSync(namedOutputOperands);
-  // synchronously execute the compiled graph.
-  context.computeSync(graph, inputs, outputs);
-  checkResults(operationName, namedOutputOperands, outputs, resources);
-};
-
-/**
- * Build a graph, asynchronously compile graph and execute, then check computed results.
+ * Build a graph, compile graph and execute, then check computed results.
  * @param {String} operationName - An operation name
  * @param {MLContext} context - A ML context
  * @param {MLGraphBuilder} builder - A ML graph builder
@@ -821,9 +801,9 @@ const runSync = (operationName, context, builder, resources, buildFunc) => {
 const run = async (operationName, context, builder, resources, buildFunc) => {
   // build a graph
   const [namedOutputOperands, inputs, outputs] = buildGraph(operationName, builder, resources, buildFunc);
-  // asynchronously compile the graph up to the output operand
+  // compile the graph up to the output operand
   const graph = await builder.build(namedOutputOperands);
-  // asynchronously execute the compiled graph
+  // execute the compiled graph
   const result = await context.compute(graph, inputs, outputs);
   checkResults(operationName, namedOutputOperands, result.outputs, resources);
 };
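The compute() call in run() resolves with a result object, which is why checkResults now reads result.outputs rather than the caller-supplied buffers. Below is a hedged sketch of driving a compiled graph directly; the 'input'/'output' names match the operands used in the idlharness change, and the buffer sizes are illustrative.

// Sketch: executing a compiled MLGraph with the async API.
const inputBuffer = new Float32Array(25).fill(1);  // matches a 1x1x5x5 input
const outputBuffer = new Float32Array(9);          // matches a 1x1x3x3 conv2d output
const result = await context.compute(
    graph, {'input': inputBuffer}, {'output': outputBuffer});
// The passed-in buffers may be transferred by compute(), so read the returned views.
const computed = result.outputs['output'];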
Expand All @@ -842,41 +822,18 @@ const testWebNNOperation = (operationName, buildFunc, deviceType = 'cpu') => {
operationNameArray = operationName;
}

ExecutionArray.forEach(executionType => {
const isSync = executionType === 'sync';
if (self.GLOBAL.isWindow() && isSync) {
return;
}
let context;
let builder;
if (isSync) {
// test sync
operationNameArray.forEach((subOperationName) => {
const tests = loadTests(subOperationName);
setup(() => {
context = navigator.ml.createContextSync({deviceType});
builder = new MLGraphBuilder(context);
});
for (const subTest of tests) {
test(() => {
runSync(subOperationName, context, builder, subTest, buildFunc);
}, `${subTest.name} / ${executionType}`);
}
});
} else {
// test async
operationNameArray.forEach((subOperationName) => {
const tests = loadTests(subOperationName);
promise_setup(async () => {
context = await navigator.ml.createContext({deviceType});
builder = new MLGraphBuilder(context);
});
for (const subTest of tests) {
promise_test(async () => {
await run(subOperationName, context, builder, subTest, buildFunc);
}, `${subTest.name} / ${executionType}`);
}
});
let context;
let builder;
operationNameArray.forEach((subOperationName) => {
const tests = loadTests(subOperationName);
promise_setup(async () => {
context = await navigator.ml.createContext({deviceType});
builder = new MLGraphBuilder(context);
});
for (const subTest of tests) {
promise_test(async () => {
await run(subOperationName, context, builder, subTest, buildFunc);
}, `${subTest.name}`);
}
});
};
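With the sync path gone, an operation test file calls the helper once and every subtest runs through promise_test. A minimal usage sketch follows; the buildOperationWithSingleInput build function is assumed here for illustration and is not part of this diff.

// In an operation test (e.g. a relu test file), after loading resources/utils.js:
testWebNNOperation('relu', buildOperationWithSingleInput, 'cpu');
// Passing an array groups related operations under one build function:
testWebNNOperation(['relu', 'sigmoid'], buildOperationWithSingleInput, 'gpu');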