Skip to content

Commit

Permalink
feat(simd): add new dot fns, tests, rename
Browse files Browse the repository at this point in the history
  • Loading branch information
postspectacular committed Oct 19, 2019
1 parent 2f50df6 commit 50bc9fc
Show file tree
Hide file tree
Showing 7 changed files with 286 additions and 24 deletions.
12 changes: 12 additions & 0 deletions packages/simd/assembly/align.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// @ts-ignore: decorator
@inline
export function align(x: usize, base: usize): usize {
base--;
return (x + base) & ~base;
}

// @ts-ignore: decorator
@inline
export function isAligned(x: usize, base: usize): boolean {
return (x & (base - 1)) === 0;
}
81 changes: 79 additions & 2 deletions packages/simd/assembly/dot.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,45 @@
/**
* f32x4 dot product. `so` should be 1 for packed result buffer.
* Takes two densely packed vec2 AOS buffers `a` and `b`, computes their
* 2D dot products and stores results in `out`. Computes two results per
* iteration, hence `num` must be an even number or else the last vector
* will not be processed. `so` should be 1 for packed result buffer.
*
* `a` & `b` should be aligned to 16, `out` to multiples of 4.
*
* @param out
* @param a
* @param b
* @param num
* @param so
*/
export function dot2_f32_aos(
out: usize,
a: usize,
b: usize,
num: usize,
so: usize
): usize {
const res = out;
const so2 = so << 3;
so <<= 2;
num >>= 1;
for (; num-- > 0; ) {
let m = v128.mul<f32>(v128.load(a), v128.load(b));
m = v128.add<f32>(m, v128.shuffle<f32>(m, m, 1, 0, 3, 2));
store<f32>(out, v128.extract_lane<f32>(m, 0));
store<f32>(out + so, v128.extract_lane<f32>(m, 2));
out += so2;
a += 16;
b += 16;
}
return res;
}

/**
* Takes two vec4 AOS buffers, computes their dot products and stores
* results in `out`. `so` should be 1 for packed result buffer. `sa` and
* `sb` indicate the stride lengths (in floats) between each vector in
* each respective buffer and should be a multiple of 4.
*
* @param out
* @param a
Expand All @@ -9,7 +49,7 @@
* @param sa
* @param sb
*/
export function dot4(
export function dot4_f32_aos(
out: usize,
a: usize,
b: usize,
Expand All @@ -36,3 +76,40 @@ export function dot4(
}
return res;
}

export function dot4_f32_soa(
out: usize,
a: usize,
b: usize,
num: usize,
sa: usize,
sb: usize
): usize {
sa <<= 2;
sb <<= 2;
num >>= 2;
const sa2 = sa * 2;
const sb2 = sb * 2;
const sa3 = sa * 3;
const sb3 = sb * 3;
const res = out;
for (; num-- > 0; ) {
v128.store(
out,
v128.add<f32>(
v128.add<f32>(
v128.add<f32>(
v128.mul<f32>(v128.load(a), v128.load(b)),
v128.mul<f32>(v128.load(a + sa), v128.load(b + sb))
),
v128.mul<f32>(v128.load(a + sa2), v128.load(b + sb2))
),
v128.mul<f32>(v128.load(a + sa3), v128.load(b + sb3))
)
);
out += 16;
sa += 16;
sb += 16;
}
return res;
}
19 changes: 10 additions & 9 deletions packages/simd/assembly/madd.ts
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
/**
* f32x4 multiply-add: out = a * b + c.
* Takes three vec4 buffers, computes componentwise a * b + c and stores
* results in `out`. Both AOS / SOA layouts are supported, as long as
* all buffers are using the same layout.
*
* `num` and all strides must by multiples of 4.
* All pointers must be aligned to multiples of 16.
* Returns `out` pointer.
* All strides must by multiples of 4. All pointers must be aligned to
* multiples of 16. Returns `out` pointer.
*
* @param out
* @param a
* @param b
* @param c
* @param num number of 4D vectors
* @param num number of vec4
* @param so out element stride
* @param sa
* @param sb
* @param sc
* @param sa A element stride
* @param sb B element stride
* @param sc C element stride
*/
export function madd4(
export function madd4_f32(
out: usize,
a: usize,
b: usize,
Expand Down
2 changes: 1 addition & 1 deletion packages/simd/assembly/maddn.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
export function maddN4(
export function maddN4_f32(
out: usize,
a: usize,
b: f32,
Expand Down
6 changes: 3 additions & 3 deletions packages/simd/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
"build:release": "yarn clean && yarn build:wasm && yarn build:es6 && node ../../scripts/bundle-module all",
"build:es6": "tsc --declaration",
"build:test": "rimraf build && tsc -p test/tsconfig.json",
"build:wasm": "asc assembly/index.ts -b simd.wasm -t simd.wat --validate --optimize --enable simd --runtime none --importMemory --memoryBase 1024",
"test": "yarn build:test && mocha build/test/*.js",
"cover": "yarn build:test && nyc mocha build/test/*.js && nyc report --reporter=lcov",
"build:wasm": "asc assembly/index.ts -b simd.wasm -t simd.wat --validate --optimize --enable simd --runtime none --importMemory --memoryBase 0",
"test": "yarn build:test && node --experimental-wasm-simd build/test/index.js",
"cover": "yarn build:test && nyc node --experimental-wasm-simd build/test/*.js && nyc report --reporter=lcov",
"clean": "rimraf *.js *.d.ts .nyc_output build coverage doc lib",
"doc": "node_modules/.bin/typedoc --mode modules --out doc --ignoreCompilerErrors src",
"pub": "yarn build:release && yarn publish --access public"
Expand Down
110 changes: 106 additions & 4 deletions packages/simd/src/index.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,103 @@
import { isString } from "@thi.ng/checks";

export interface SIMD {
/**
* WASM memory instance given to `init()`.
*/
memory: WebAssembly.Memory;
/**
* Float64 view of WASM memory.
*/
f64: Float64Array;
/**
* Float32 view of WASM memory.
*/
f32: Float32Array;
/**
* Uint32 view of WASM memory.
*/
u32: Uint32Array;
/**
* Int32 view of WASM memory.
*/
i32: Int32Array;
/**
* Uint16 of WASM memory.
*/
u16: Uint16Array;
/**
* Int16 view of WASM memory.
*/
i16: Int16Array;
/**
* Uint8 view of WASM memory.
*/
u8: Uint8Array;
/**
* Int8 view of WASM memory.
*/
i8: Int8Array;

/**
* Takes two densely packed vec2 AOS buffers `a` and `b`, computes their
* 2D dot products and stores results in `out`. Computes two results per
* iteration, hence `num` must be an even number or else the last vector
* will not be processed. `so` should be 1 for packed result buffer.
*
* `a` & `b` should be aligned to 16, `out` to multiples of 4.
*
* @param out
* @param a
* @param b
* @param num
* @param so
*/
// prettier-ignore
dot2_f32_aos(out: number, a: number, b: number, num: number, so: number): number;

/**
* Takes two vec4 AOS buffers, computes their dot products and stores
* results in `out`. `so` should be 1 for packed result buffer. `sa` and
* `sb` indicate the stride lengths (in floats) between each vector in
* each respective buffer and should be a multiple of 4.
*
* @param out
* @param a
* @param b
* @param num
* @param so
* @param sa
* @param sb
*/
// prettier-ignore
dot4_f32_aos(out: number, a: number, b: number, num: number, so: number, sa: number, sb: number): number;

// prettier-ignore
dot4(out: number, a: number, b: number, num: number, so: number, sa: number, sb: number): number;
dot4_f32_soa(out: number, a: number, b: number, num: number, sa: number, sb: number): number;

/**
* Takes three vec4 buffers, computes componentwise `a * b + c` and stores
* results in `out`. Both AOS / SOA layouts are supported, as long as
* all buffers are using the same layout.
*
* All strides must by multiples of 4. All pointers should be aligned to
* multiples of 16. Returns `out` pointer.
*
* @param out
* @param a
* @param b
* @param c
* @param num number of vec4
* @param so out element stride
* @param sa A element stride
* @param sb B element stride
* @param sc C element stride
*/
// prettier-ignore
madd4(out: number, a: number, b: number, c: number, num: number, so: number, sa: number, sb: number, sc: number): number;
madd4_f32(out: number, a: number, b: number, c: number, num: number, so: number, sa: number, sb: number, sc: number): number;

// prettier-ignore
maddN4(out: number, a: number, b: number, c: number, num: number, so: number, sa: number, sc: number): number;
maddn4_f32(out: number, a: number, b: number, c: number, num: number, so: number, sa: number, sc: number): number;
}

export const init = async (
Expand All @@ -33,5 +124,16 @@ export const init = async (
imports
);
}
return wasm.exports;
const buf = memory.buffer;
return <SIMD>{
...wasm.exports,
f32: new Float32Array(buf),
f64: new Float64Array(buf),
u32: new Uint32Array(buf),
i32: new Int32Array(buf),
u16: new Uint16Array(buf),
i16: new Int16Array(buf),
u8: new Uint8Array(buf),
i8: new Int8Array(buf)
};
};
80 changes: 75 additions & 5 deletions packages/simd/test/index.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,76 @@
// import * as assert from "assert";
// import * as s from "../src/index";
import { equiv } from "@thi.ng/equiv";
import * as fs from "fs";
import { init } from "../src";

describe("simd", () => {
it("tests pending");
});
const assertEqual = (res: any, exp: any, msg?: string) => {
if (!equiv(res, exp)) {
process.stderr.write(msg || `expected: ${exp}, got ${res}\n\n`);
process.exit(1);
}
};

(async () => {
const simd = await init(
fs.readFileSync("simd.wasm"),
new WebAssembly.Memory({ initial: 1 })
);

// dot2_aos
// prettier-ignore
simd.f32.set([
// a
1, 2, 3, 4,
// b
10, 20, 30, 40
]);
simd.dot2_f32_aos(1024, 0, 16, 2, 1);
assertEqual(simd.f32.slice(1024 / 4, 1024 / 4 + 2), [50, 250]);

// dot4_aos
// prettier-ignore
simd.f32.set([
// a
1, 2, 3, 4, 5, 6, 7, 8,
// b
10, 20, 30, 40, 50, 60, 70, 80
]);
simd.dot4_f32_aos(1024, 0, 32, 2, 1, 4, 4);
assertEqual(simd.f32.slice(1024 / 4, 1024 / 4 + 2), [300, 1740]);

// dot4_soa
// prettier-ignore
simd.f32.set([
// ax
1, 2, 3, 4,
// ay
1, 2, 3, 4,
// az
1, 2, 3, 4,
// aw
1, 2, 3, 4,
// bx
10, 10, 10, 10,
// by
20, 20, 20, 20,
// bz
30, 30, 30, 30,
// bw
40, 40, 40, 40
]);
simd.dot4_f32_soa(1024, 0, 64, 4, 4, 4);
assertEqual(simd.f32.slice(1024 / 4, 1024 / 4 + 4), [100, 200, 300, 400]);

// madd4
// prettier-ignore
simd.f32.set([
// a
1, 2, 3, 4, 5, 6, 7, 8,
// b
11, 11, 11, 11, 11, 11, 11, 11,
// c
100, 200, 300, 400, 500, 600, 700, 800
]);
simd.madd4_f32(1024, 0, 32, 64, 2, 4, 4, 4, 4);
// prettier-ignore
assertEqual(simd.f32.slice(1024 / 4, 1024 / 4 + 8), [111, 222, 333, 444, 555, 666, 777, 888]);
})();

0 comments on commit 50bc9fc

Please sign in to comment.