Skip to content

Commit

Permalink
Apply MTR's feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
abrown committed Apr 1, 2024
1 parent b2594a9 commit af4495e
Showing 1 changed file with 13 additions and 15 deletions.
28 changes: 13 additions & 15 deletions wit/wasi-nn.wit
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,17 @@ world ml {
/// Inference is performed on a specific `device`.
interface device {
/// Define where tensors reside and graphs execute.
enum location {
cpu,
gpu,
tpu
record device {
name: string
}

/// List the available devices for a given backend.
available-devices: func(backend: backend) -> list<device>;
}

/// All inputs and outputs to an ML inference are represented as `tensor`s.
interface tensor {
use device.{location};
use device.{device};

/// The dimensions of a tensor.
///
Expand Down Expand Up @@ -67,26 +68,23 @@ interface tensor {
ty: func() -> tensor-type;

// Describe where the tensor is currently located (e.g., `cpu`, `gpu`, `tpu`).
location: func() -> location;
location: func() -> device;

// Move the tensor to a different device. This operation may result in an expensive data
// copy.
move-to: func(device: device) -> result<tensor, error>;

// Return the tensor data. If the tensor is located on a device other than the CPU, this
// operation may result in an expensive data copy operation.
data: func() -> tensor-data;
}

/// Alternately, construct a tensor that lives exclusively on a specific device.
create_on_device: func(dimensions: tensor-dimensions, ty: tensor-type, data: tensor-data,
location: execution-target, backend: graph-encoding) -> result<tensor, error>;

// TODO: rename exection-target to... device?
// TODO: rename graph-encoding to... backend?
}

/// A `graph` is a loaded instance of a specific ML model (e.g., MobileNet) for a specific ML
/// framework (e.g., TensorFlow):
interface graph {
use errors.{error};
use device.{location};
use device.{device};
use inference.{graph-execution-context};
use tensor.{tensor};

Expand Down Expand Up @@ -114,7 +112,7 @@ interface graph {

/// Load a `graph` from an opaque sequence of bytes to use for inference on the specified device
/// `location`.
load: func(builder: list<graph-builder>, encoding: graph-encoding, location: location) -> result<graph, error>;
load: func(builder: list<graph-builder>, encoding: graph-encoding, location: device) -> result<graph, error>;

/// Load a `graph` by name.
///
Expand Down

0 comments on commit af4495e

Please sign in to comment.