diff --git a/docs_src/tutorials/how_to_use_procs.md b/docs_src/tutorials/how_to_use_procs.md index 4bae531590..07299666df 100644 --- a/docs_src/tutorials/how_to_use_procs.md +++ b/docs_src/tutorials/how_to_use_procs.md @@ -142,12 +142,10 @@ function. A spawn produces no value, hence no `let` on the left-hand side. ### Channel arrays and loop-based spawning -> **Note**: This feature is currently WIP and is not yet available. - Many hardware layouts have regular arrays of components, such as systolic arrays, vector units, etc. Individually specifying these quickly grows cumbersome, so users can instead declare arrays of channels and spawn procs -inside `for` loops. This looks as follows: +inside `unroll_for!` loops. This looks as follows: ```dslx-snippet proc Spawner4x4 { @@ -160,8 +158,8 @@ proc Spawner4x4 { let (input_producers, input_consumers) = chan[4][4]("node_input"); let (output_producers, output_consumers) = chan[4][4]("node_output"); - for (i, _) : (u32, ()) in range(u32:0, u32:4) { - for (j, _) : (u32, ()) in range(u32:0, u32:4) { + unroll_for! (i, _) : (u32, ()) in range(u32:0, u32:4) { + unroll_for! (j, _) : (u32, ()) in range(u32:0, u32:4) { spawn Node(input_consumers[i][j], output_producers[i][j]); }(()); diff --git a/xls/examples/matmul_4x4/matmul_4x4.x b/xls/examples/matmul_4x4/matmul_4x4.x index 416f0c0664..6db42022ad 100644 --- a/xls/examples/matmul_4x4/matmul_4x4.x +++ b/xls/examples/matmul_4x4/matmul_4x4.x @@ -14,10 +14,6 @@ // DSLX implementation of a 4x4 systolic array, appropriate for part of a // matrix multiplier. - -// TODO(rspringer): 2021-09-16, issue #497: The channel declarations here are a -// bit unwieldy; if we can use arrays-of-channels, that'll make things cleaner. - import float32; type F32 = float32::F32; @@ -51,73 +47,69 @@ proc node { } } -proc matmul { - zeroes_out: chan[COLS] out; - voids_in: chan[ROWS] in; +proc matmul +{ + activations_in: chan[ROWS] in; + results_out: chan[COLS] out; + west_inputs: chan[COLS + u32:1][ROWS] in; + east_outputs: chan[COLS + u32:1][ROWS] out; + north_inputs: chan[COLS][ROWS + u32:1] in; + south_outputs: chan[COLS][ROWS + u32:1] out; config(activations_in: chan[ROWS] in, results_out: chan[COLS] out) { // Declare the east-to-west channels. - let (east_outputs, west_inputs) = chan[COLS - u32:1][ROWS]("east_west"); - + let (east_outputs, west_inputs) = chan[COLS + u32:1][ROWS]("east_west"); // Declare the north-to-south channels. - let (south_outputs, north_inputs) = chan[COLS][ROWS - u32:1]("north_south"); - - // TODO(rspringer): Zeros (as initial partial sums) would be best provided - // by single-value channels. - // Declare the zero-valued initial partial sum channels. - let (zeroes_out, zeroes_in) = chan[COLS]("zeros"); - - // Declare void channels for the east-edges of the array. - let (voids_out, voids_in) = chan[ROWS]("void"); - - // Spawn all the procs. Specify weights to give a "mul-by-two" matrix. - let f32_0 = float32::zero(false); - let f32_2 = F32 { sign: false, bexp: u8:128, fraction: u23:0 }; - - // TODO(https://github.com/google/xls/issues/585): We can't loop (and thus - // parameterize) this until we can constexpr evaluate `for` expressions. - spawn node(activations_in[0], zeroes_in[0], east_outputs[0][0], south_outputs[0][0], f32_2); - spawn node(west_inputs[0][0], zeroes_in[1], east_outputs[0][1], south_outputs[0][1], f32_0); - spawn node(west_inputs[0][1], zeroes_in[2], east_outputs[0][2], south_outputs[0][2], f32_0); - spawn node(west_inputs[0][2], zeroes_in[3], voids_out[0], south_outputs[0][3], f32_0); - - spawn node( - activations_in[1], north_inputs[0][0], east_outputs[1][0], south_outputs[1][0], f32_0); - spawn node( - west_inputs[1][0], north_inputs[0][1], east_outputs[1][1], south_outputs[1][1], f32_2); - spawn node( - west_inputs[1][1], north_inputs[0][2], east_outputs[1][2], south_outputs[1][2], f32_0); - spawn node(west_inputs[1][2], north_inputs[0][3], voids_out[1], south_outputs[1][3], f32_0); - - spawn node( - activations_in[2], north_inputs[1][0], east_outputs[2][0], south_outputs[2][0], f32_0); - spawn node( - west_inputs[2][0], north_inputs[1][1], east_outputs[2][1], south_outputs[2][1], f32_0); - spawn node( - west_inputs[2][1], north_inputs[1][2], east_outputs[2][2], south_outputs[2][2], f32_2); - spawn node(west_inputs[2][2], north_inputs[1][3], voids_out[2], south_outputs[2][3], f32_0); - - spawn node(activations_in[3], north_inputs[2][0], east_outputs[3][0], results_out[0], f32_0); - spawn node(west_inputs[3][0], north_inputs[2][1], east_outputs[3][1], results_out[1], f32_0); - spawn node(west_inputs[3][1], north_inputs[2][2], east_outputs[3][2], results_out[2], f32_0); - spawn node(west_inputs[3][2], north_inputs[2][3], voids_out[3], results_out[3], f32_2); - - (zeroes_out, voids_in) + let (south_outputs, north_inputs) = chan[COLS][ROWS + u32:1]("north_south"); + unroll_for! (row, _): (u32, ()) in u32:0..ROWS { + unroll_for! (col, _): (u32, ()) in u32:0..COLS { + let weight = F32 { + sign: false, + bexp: if col == row { u8:128 } else { u8:0 }, + fraction: u23:0, + }; + spawn node( + west_inputs[row][col], north_inputs[row][col], east_outputs[row][col + u32:1], + south_outputs[row + u32:1][col], weight); + }(()); + }(()); + (activations_in, results_out, west_inputs, east_outputs, north_inputs, south_outputs) } init { () } - // All we need to do is to push in "zero" values to the top of the array and consume void - // channels to keep the system moving. next(state: ()) { - send(join(), zeroes_out[0], float32::zero(false)); - send(join(), zeroes_out[1], float32::zero(false)); - send(join(), zeroes_out[2], float32::zero(false)); - send(join(), zeroes_out[3], float32::zero(false)); - recv(join(), voids_in[0]); - recv(join(), voids_in[1]); - recv(join(), voids_in[2]); - recv(join(), voids_in[3]); + // Send activation to the "left"-end of the array. + let activations_col = u32:0; + unroll_for! (row, _): (u32, ()) in u32:0..ROWS { + let (tok, activation) = recv(join(), activations_in[row]); + send(tok, east_outputs[row][activations_col], activation); + }(()); + + // Send zero values to the "top"-end of the array. + let zeroes_row = u32:0; + unroll_for! (col, _): (u32, ()) in u32:0..COLS { + send(join(), south_outputs[zeroes_row][col], float32::zero(false)); + }(()); + + // Consume and drop values on the "right"-end of the array. + // TODO - google/xls#1750: remove unroll_for! workaround. + unroll_for! (drops_col, _): (u32, ()) in COLS..COLS + u32:1 { + unroll_for! (row, _): (u32, ()) in u32:0..ROWS { + recv(join(), west_inputs[row][drops_col]); + }(()); + }(()); + + // Forward result from the "bottom"-end of the array. + // TODO - google/xls#1750: remove unroll_for! workaround. + unroll_for! (results_row, _): (u32, ()) in ROWS..ROWS + u32:1 { + unroll_for! (col, _): (u32, ()) in u32:0..COLS { + let (tok, result) = recv(join(), north_inputs[results_row][col]); + send(tok, results_out[col], result); + }(()); + }(()); } } @@ -140,7 +132,7 @@ proc test_proc { config(terminator: chan out) { let (activations_out, activations_in) = chan[4]("activations"); let (results_out, results_in) = chan[4]("results"); - spawn matmul(activations_in, results_out); + spawn matmul_4x4(activations_in, results_out); (activations_out, results_in, terminator) } @@ -152,32 +144,18 @@ proc test_proc { let f32_4 = F32 { sign: false, bexp: u8:129, fraction: u23:0 }; // Send the desired inputs. - let tok = for (i, tok): (u32, token) in range(u32:0, u32:4) { - send(tok, activations_out[i], f32_2) + let tok = unroll_for! (i, tok): (u32, token) in u32:0..u32:4 { + send(tok, activations_out[i], f32_2); + tok }(join()); - // Send extra inputs to keep the system moving while our results are processing. - let tok = for (_, tok): (u32, token) in range(u32:0, u32:4) { - for (i, tok): (u32, token) in range(u32:0, u32:4) { - send(tok, activations_out[i], f32_0) - }(tok) - }(tok); - - // Flush the intermediate values. - let tok = for (_, tok): (u32, token) in range(u32:0, u32:0) { - for (i, tok): (u32, token) in range(u32:0, u32:4) { - let (tok, _) = recv(tok, results_in[i]); - tok - }(tok) - }(tok); - - let (tok, value) = recv(tok, results_in[0]); + let (_, value) = recv(tok, results_in[0]); assert_eq(value, f32_0); - let (tok, value) = recv(tok, results_in[1]); + let (_, value) = recv(tok, results_in[1]); assert_eq(value, f32_0); - let (tok, value) = recv(tok, results_in[2]); + let (_, value) = recv(tok, results_in[2]); assert_eq(value, f32_0); - let (tok, value) = recv(tok, results_in[3]); + let (_, value) = recv(tok, results_in[3]); assert_eq(value, f32_4); let tok = send(tok, terminator, true);