diff --git a/wit/wasi-nn.wit b/wit/wasi-nn.wit
index 3e54249..fa13f2f 100644
--- a/wit/wasi-nn.wit
+++ b/wit/wasi-nn.wit
@@ -15,8 +15,21 @@ world ml {
     import errors;
 }
 
+/// Inference is performed on a specific `device`.
+interface device {
+    /// Define where tensors reside and graphs execute.
+    record device {
+        name: string
+    }
+
+    /// List the available devices for a given backend.
+    available-devices: func(backend: backend) -> list<device>;
+}
+
 /// All inputs and outputs to an ML inference are represented as `tensor`s.
 interface tensor {
+    use device.{device};
+
     /// The dimensions of a tensor.
     ///
     /// The array length matches the tensor rank and each element in the array describes the size of
@@ -44,6 +57,7 @@ interface tensor {
     type tensor-data = list<u8>;
 
     resource tensor {
+        /// Construct a tensor that lives on the host CPU.
        constructor(dimensions: tensor-dimensions, ty: tensor-type, data: tensor-data);
 
        // Describe the size of the tensor (e.g., 2x2x2x2 -> [2, 2, 2, 2]). To represent a tensor
@@ -53,7 +67,15 @@ interface tensor {
        // Describe the type of element in the tensor (e.g., `f32`).
        ty: func() -> tensor-type;
 
-        // Return the tensor data.
+        // Describe where the tensor is currently located (e.g., `cpu`, `gpu`, `tpu`).
+        location: func() -> device;
+
+        // Move the tensor to a different device. This operation may result in an expensive data
+        // copy.
+        move-to: func(device: device) -> result;
+
+        // Return the tensor data. If the tensor is located on a device other than the CPU, this
+        // operation may result in an expensive data copy.
        data: func() -> tensor-data;
     }
 }
@@ -62,8 +84,9 @@ interface tensor {
 /// framework (e.g., TensorFlow):
 interface graph {
     use errors.{error};
-    use tensor.{tensor};
+    use device.{device};
     use inference.{graph-execution-context};
+    use tensor.{tensor};
 
     /// An execution graph for performing inference (i.e., a model).
     resource graph {
@@ -81,21 +104,15 @@ interface graph {
         autodetect,
     }
 
-    /// Define where the graph should be executed.
-    enum execution-target {
-        cpu,
-        gpu,
-        tpu
-    }
-
     /// The graph initialization data.
     ///
     /// This gets bundled up into an array of buffers because implementing backends may encode their
     /// graph IR in parts (e.g., OpenVINO stores its IR and weights separately).
     type graph-builder = list<u8>;
 
-    /// Load a `graph` from an opaque sequence of bytes to use for inference.
-    load: func(builder: list<graph-builder>, encoding: graph-encoding, target: execution-target) -> result<graph, error>;
+    /// Load a `graph` from an opaque sequence of bytes to use for inference on the specified device
+    /// `location`.
+    load: func(builder: list<graph-builder>, encoding: graph-encoding, location: device) -> result<graph, error>;
 
     /// Load a `graph` by name.
     ///
@@ -116,6 +133,11 @@ interface inference {
     /// TODO: this may no longer be necessary in WIT
     /// (https://github.com/WebAssembly/wasi-nn/issues/43)
     resource graph-execution-context {
+        /// Load a tensor using the graph context. Unlike the `tensor` constructor, this function
+        /// will co-locate the tensor data on a specific device using the graph's underlying
+        /// backend; this may avoid some copies, improving performance.
+        load-tensor: func(dimensions: tensor-dimensions, ty: tensor-type, data: tensor-data) -> result<tensor, error>;
+
        /// Define the inputs to use for inference.
        set-input: func(name: string, tensor: tensor) -> result<_, error>;
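
For reviewers, here is roughly how the new surface would be used from a guest. This is a minimal sketch assuming wit-bindgen-generated Rust bindings for the `ml` world; the module paths, the `infer` helper, the `backend` value, and the tensor/output names are illustrative assumptions, and `compute`/`get-output` come from the existing `inference` interface rather than from this diff:

```rust
// A sketch against hypothetical wit-bindgen Rust bindings for this world.
// Module paths and snake_case names follow wit-bindgen's usual kebab-case
// mapping, but none of them are generated output from this PR.
use ml::device::available_devices;
use ml::graph::{load, GraphEncoding};
use ml::tensor::TensorType;

fn infer(model: &[Vec<u8>], input: &[u8]) -> Vec<u8> {
    // Pick a device instead of a variant of the removed `execution-target`
    // enum. The `backend` argument is elided because this diff does not show
    // its definition.
    let device = available_devices(todo!("backend"))
        .into_iter()
        .next()
        .expect("no devices available");

    // `load` now takes a device `location` in place of `execution-target`.
    let graph = load(model, GraphEncoding::Autodetect, &device).expect("load");
    let ctx = graph.init_execution_context().expect("init context");

    // `load-tensor` co-locates the tensor with the graph's device, which can
    // avoid the copy implied by building a CPU tensor and moving it over.
    let tensor = ctx
        .load_tensor(&[1, 3, 224, 224], TensorType::Fp32, input)
        .expect("load tensor");
    ctx.set_input("input", tensor).expect("set input");

    // `compute` and `get-output` are unchanged by this diff.
    ctx.compute().expect("compute");
    ctx.get_output("output").expect("get output").data()
}
```

The sketch highlights the intended benefit of `load-tensor`: it keeps the input data on the graph's device from the start, whereas the constructor-then-`move-to` path pays an explicit copy.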