Skip to content

Commit

Permalink
xls/examples/matmul: use unroll_for
Browse files Browse the repository at this point in the history
- use unroll_for for spawn and channel operations
- forward input and result values explicitly in next
- update proc docs

PiperOrigin-RevId: 701969798
  • Loading branch information
proppy authored and copybara-github committed Dec 2, 2024
1 parent b35958e commit ff61af4
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 89 deletions.
8 changes: 3 additions & 5 deletions docs_src/tutorials/how_to_use_procs.md
Original file line number Diff line number Diff line change
Expand Up @@ -142,12 +142,10 @@ function. A spawn produces no value, hence no `let` on the left-hand side.

### Channel arrays and loop-based spawning

> **Note**: This feature is currently WIP and is not yet available.
Many hardware layouts have regular arrays of components, such as systolic
arrays, vector units, etc. Individually specifying these quickly grows
cumbersome, so users can instead declare arrays of channels and spawn procs
inside `for` loops. This looks as follows:
inside `unroll_for!` loops. This looks as follows:

```dslx-snippet
proc Spawner4x4 {
Expand All @@ -160,8 +158,8 @@ proc Spawner4x4 {
let (input_producers, input_consumers) = chan<F32>[4][4]("node_input");
let (output_producers, output_consumers) = chan<F32>[4][4]("node_output");
for (i, _) : (u32, ()) in range(u32:0, u32:4) {
for (j, _) : (u32, ()) in range(u32:0, u32:4) {
unroll_for! (i, _) : (u32, ()) in range(u32:0, u32:4) {
unroll_for! (j, _) : (u32, ()) in range(u32:0, u32:4) {
spawn Node(input_consumers[i][j],
output_producers[i][j]);
}(());
Expand Down
146 changes: 62 additions & 84 deletions xls/examples/matmul_4x4/matmul_4x4.x
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,6 @@

// DSLX implementation of a 4x4 systolic array, appropriate for part of a
// matrix multiplier.

// TODO(rspringer): 2021-09-16, issue #497: The channel declarations here are a
// bit unwieldy; if we can use arrays-of-channels, that'll make things cleaner.

import float32;

type F32 = float32::F32;
Expand Down Expand Up @@ -51,73 +47,69 @@ proc node {
}
}

proc matmul<ROWS: u32, COLS: u32> {
zeroes_out: chan<F32>[COLS] out;
voids_in: chan<F32>[ROWS] in;
proc matmul<ROWS: u32, COLS: u32, ROWS_PLUS_1: u32 = {ROWS + u32:1}, COLS_PLUS_1:
u32 = {
COLS + u32:1}>
{
activations_in: chan<F32>[ROWS] in;
results_out: chan<F32>[COLS] out;
west_inputs: chan<F32>[COLS + u32:1][ROWS] in;
east_outputs: chan<F32>[COLS + u32:1][ROWS] out;
north_inputs: chan<F32>[COLS][ROWS + u32:1] in;
south_outputs: chan<F32>[COLS][ROWS + u32:1] out;

config(activations_in: chan<F32>[ROWS] in, results_out: chan<F32>[COLS] out) {
// Declare the east-to-west channels.
let (east_outputs, west_inputs) = chan<F32>[COLS - u32:1][ROWS]("east_west");

let (east_outputs, west_inputs) = chan<F32>[COLS + u32:1][ROWS]("east_west");
// Declare the north-to-south channels.
let (south_outputs, north_inputs) = chan<F32>[COLS][ROWS - u32:1]("north_south");

// TODO(rspringer): Zeros (as initial partial sums) would be best provided
// by single-value channels.
// Declare the zero-valued initial partial sum channels.
let (zeroes_out, zeroes_in) = chan<F32>[COLS]("zeros");

// Declare void channels for the east-edges of the array.
let (voids_out, voids_in) = chan<F32>[ROWS]("void");

// Spawn all the procs. Specify weights to give a "mul-by-two" matrix.
let f32_0 = float32::zero(false);
let f32_2 = F32 { sign: false, bexp: u8:128, fraction: u23:0 };

// TODO(https://github.com/google/xls/issues/585): We can't loop (and thus
// parameterize) this until we can constexpr evaluate `for` expressions.
spawn node(activations_in[0], zeroes_in[0], east_outputs[0][0], south_outputs[0][0], f32_2);
spawn node(west_inputs[0][0], zeroes_in[1], east_outputs[0][1], south_outputs[0][1], f32_0);
spawn node(west_inputs[0][1], zeroes_in[2], east_outputs[0][2], south_outputs[0][2], f32_0);
spawn node(west_inputs[0][2], zeroes_in[3], voids_out[0], south_outputs[0][3], f32_0);

spawn node(
activations_in[1], north_inputs[0][0], east_outputs[1][0], south_outputs[1][0], f32_0);
spawn node(
west_inputs[1][0], north_inputs[0][1], east_outputs[1][1], south_outputs[1][1], f32_2);
spawn node(
west_inputs[1][1], north_inputs[0][2], east_outputs[1][2], south_outputs[1][2], f32_0);
spawn node(west_inputs[1][2], north_inputs[0][3], voids_out[1], south_outputs[1][3], f32_0);

spawn node(
activations_in[2], north_inputs[1][0], east_outputs[2][0], south_outputs[2][0], f32_0);
spawn node(
west_inputs[2][0], north_inputs[1][1], east_outputs[2][1], south_outputs[2][1], f32_0);
spawn node(
west_inputs[2][1], north_inputs[1][2], east_outputs[2][2], south_outputs[2][2], f32_2);
spawn node(west_inputs[2][2], north_inputs[1][3], voids_out[2], south_outputs[2][3], f32_0);

spawn node(activations_in[3], north_inputs[2][0], east_outputs[3][0], results_out[0], f32_0);
spawn node(west_inputs[3][0], north_inputs[2][1], east_outputs[3][1], results_out[1], f32_0);
spawn node(west_inputs[3][1], north_inputs[2][2], east_outputs[3][2], results_out[2], f32_0);
spawn node(west_inputs[3][2], north_inputs[2][3], voids_out[3], results_out[3], f32_2);

(zeroes_out, voids_in)
let (south_outputs, north_inputs) = chan<F32>[COLS][ROWS + u32:1]("north_south");
unroll_for! (row, _): (u32, ()) in u32:0..ROWS {
unroll_for! (col, _): (u32, ()) in u32:0..COLS {
let weight = F32 {
sign: false,
bexp: if col == row { u8:128 } else { u8:0 },
fraction: u23:0,
};
spawn node(
west_inputs[row][col], north_inputs[row][col], east_outputs[row][col + u32:1],
south_outputs[row + u32:1][col], weight);
}(());
}(());
(activations_in, results_out, west_inputs, east_outputs, north_inputs, south_outputs)
}

init { () }

// All we need to do is to push in "zero" values to the top of the array and consume void
// channels to keep the system moving.
next(state: ()) {
send(join(), zeroes_out[0], float32::zero(false));
send(join(), zeroes_out[1], float32::zero(false));
send(join(), zeroes_out[2], float32::zero(false));
send(join(), zeroes_out[3], float32::zero(false));
recv(join(), voids_in[0]);
recv(join(), voids_in[1]);
recv(join(), voids_in[2]);
recv(join(), voids_in[3]);
// Send activation to the "left"-end of the array.
let activations_col = u32:0;
unroll_for! (row, _): (u32, ()) in u32:0..ROWS {
let (tok, activation) = recv(join(), activations_in[row]);
send(tok, east_outputs[row][activations_col], activation);
}(());

// Send zero values to the "top"-end of the array.
let zeroes_row = u32:0;
unroll_for! (col, _): (u32, ()) in u32:0..COLS {
send(join(), south_outputs[zeroes_row][col], float32::zero(false));
}(());

// Consume and drop values on the "right"-end of the array.
// TODO - google/xls#1750: remove unroll_for! workaround.
unroll_for! (drops_col, _): (u32, ()) in COLS..COLS + u32:1 {
unroll_for! (row, _): (u32, ()) in u32:0..ROWS {
recv(join(), west_inputs[row][drops_col]);
}(());
}(());

// Forward result from the "bottom"-end of the array.
// TODO - google/xls#1750: remove unroll_for! workaround.
unroll_for! (results_row, _): (u32, ()) in ROWS..ROWS + u32:1 {
unroll_for! (col, _): (u32, ()) in u32:0..COLS {
let (tok, result) = recv(join(), north_inputs[results_row][col]);
send(tok, results_out[col], result);
}(());
}(());
}
}

Expand All @@ -140,7 +132,7 @@ proc test_proc {
config(terminator: chan<bool> out) {
let (activations_out, activations_in) = chan<F32>[4]("activations");
let (results_out, results_in) = chan<F32>[4]("results");
spawn matmul<u32:4, u32:4>(activations_in, results_out);
spawn matmul_4x4(activations_in, results_out);
(activations_out, results_in, terminator)
}

Expand All @@ -152,32 +144,18 @@ proc test_proc {
let f32_4 = F32 { sign: false, bexp: u8:129, fraction: u23:0 };

// Send the desired inputs.
let tok = for (i, tok): (u32, token) in range(u32:0, u32:4) {
send(tok, activations_out[i], f32_2)
let tok = unroll_for! (i, tok): (u32, token) in u32:0..u32:4 {
send(tok, activations_out[i], f32_2);
tok
}(join());

// Send extra inputs to keep the system moving while our results are processing.
let tok = for (_, tok): (u32, token) in range(u32:0, u32:4) {
for (i, tok): (u32, token) in range(u32:0, u32:4) {
send(tok, activations_out[i], f32_0)
}(tok)
}(tok);

// Flush the intermediate values.
let tok = for (_, tok): (u32, token) in range(u32:0, u32:0) {
for (i, tok): (u32, token) in range(u32:0, u32:4) {
let (tok, _) = recv(tok, results_in[i]);
tok
}(tok)
}(tok);

let (tok, value) = recv(tok, results_in[0]);
let (_, value) = recv(tok, results_in[0]);
assert_eq(value, f32_0);
let (tok, value) = recv(tok, results_in[1]);
let (_, value) = recv(tok, results_in[1]);
assert_eq(value, f32_0);
let (tok, value) = recv(tok, results_in[2]);
let (_, value) = recv(tok, results_in[2]);
assert_eq(value, f32_0);
let (tok, value) = recv(tok, results_in[3]);
let (_, value) = recv(tok, results_in[3]);
assert_eq(value, f32_4);

let tok = send(tok, terminator, true);
Expand Down

0 comments on commit ff61af4

Please sign in to comment.