Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Am/chore/wasm support #39

Merged
merged 3 commits into from
Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .cargo/config.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Accessed by wasm-bindgen when testing for the wasm target
[target.wasm32-unknown-unknown]
runner = 'wasm-bindgen-test-runner'
10 changes: 10 additions & 0 deletions .github/workflows/cargo_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,13 @@ jobs:
run: |
make test_no_std_nightly
make test_no_std_nightly FFT128_SUPPORT=ON

cargo-tests-node-js:
runs-on: "ubuntu-latest"
steps:
- uses: actions/checkout@ac593985615ec2ede58e132d2e21d2b1cbd6127c

- name: Test node js
run: |
make install_node
make test_node_js
2 changes: 1 addition & 1 deletion .github/workflows/check_commit.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jobs:
- name: Check first line
uses: gsactions/commit-message-checker@16fa2d5de096ae0d35626443bcd24f1e756cafee
with:
pattern: '^((feat|fix|chore|refactor|style|test|docs|doc)\(\w+\)\:) .+$'
pattern: '^((feat|fix|chore|refactor|style|test|docs|doc)(\([\w\-_]+\))?\!?\:) .+$'
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you clarify the case you want to avoid ? If I undertand correctly you allow hypen and underscore within the parenthesis but not the exclamation mark right ?
What was the issue you've encountered ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's something we did with Nico on the TFHE-rs repo to allow hypens, make the content of the parentheses optional and the ! is to allow indicating breaking change as conventional commits permits

flags: "gs"
error: 'Your first line has to contain a commit type and scope like "feat(my_feature): msg".'
excludeDescription: "true" # optional: this excludes the description body of a pull request
Expand Down
18 changes: 14 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "concrete-fft"
version = "0.5.0"
version = "0.5.1"
edition = "2021"
authors = ["sarah el kazdadi <sarah.elkazdadi@zama.ai>"]
description = "Concrete-FFT is a pure Rust high performance fast Fourier transform library."
Expand All @@ -18,6 +18,9 @@ num-complex = { version = "0.4", features = ["bytemuck"] }
pulp = { version = "0.18.22", default-features = false }
serde = { version = "1.0", optional = true, default-features = false }

[target.'cfg(target_arch = "wasm32")'.dependencies]
js-sys = "0.3"

[features]
default = ["std"]
fft128 = []
Expand All @@ -26,17 +29,24 @@ std = ["pulp/std"]
serde = ["dep:serde", "num-complex/serde"]

[dev-dependencies]
criterion = "0.4"
rustfft = "6.0"
fftw-sys = { version = "0.6", default-features = false, features = ["system"] }
rand = "0.8"
bincode = "1.3"
more-asserts = "0.3.1"
serde_json = "1.0.96"

[target.'cfg(not(target_os = "windows"))'.dev-dependencies]
[target.'cfg(target_arch = "wasm32")'.dev-dependencies]
wasm-bindgen-test = "0.3"
wasm-bindgen = "0.2.86"
getrandom = { version = "0.2", features = ["js"] }

[target.'cfg(all(not(target_os = "windows"), not(target_arch = "wasm32")))'.dev-dependencies]
rug = "1.19.1"

[target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies]
criterion = "0.4"
fftw-sys = { version = "0.6", default-features = false, features = ["system"] }

[[bench]]
name = "fft"
harness = false
Expand Down
43 changes: 41 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ CARGO_RS_CHECK_TOOLCHAIN:=+$(RS_CHECK_TOOLCHAIN)
RS_BUILD_TOOLCHAIN:=stable
CARGO_RS_BUILD_TOOLCHAIN:=+$(RS_BUILD_TOOLCHAIN)
MIN_RUST_VERSION:=1.65
WASM_BINDGEN_VERSION:=$(shell grep '^wasm-bindgen[[:space:]]*=' Cargo.toml | cut -d '=' -f 2 | xargs)
NODE_VERSION=22.6
AVX512_SUPPORT?=OFF
FFT128_SUPPORT?=OFF
# This is done to avoid forgetting it, we still precise the RUSTFLAGS in the commands to be able to
Expand Down Expand Up @@ -47,6 +49,35 @@ install_rs_build_toolchain:
( echo "Unable to install $(RS_BUILD_TOOLCHAIN) toolchain, check your rustup installation. \
Rustup can be downloaded at https://rustup.rs/" && exit 1 )

.PHONY: install_build_wasm32_target # Install the wasm32 toolchain used for builds
install_build_wasm32_target: install_rs_build_toolchain
rustup +$(RS_BUILD_TOOLCHAIN) target add wasm32-unknown-unknown || \
( echo "Unable to install wasm32-unknown-unknown target toolchain, check your rustup installation. \
Rustup can be downloaded at https://rustup.rs/" && exit 1 )

# The installation uses the ^ symbol because we need the matching version of wasm-bindgen in the
# Cargo.toml, as we don't lock those dependencies, this allows to get the matching CLI
.PHONY: install_wasm_bindgen_cli # Install wasm-bindgen-cli to get access to the test runner
install_wasm_bindgen_cli: install_rs_build_toolchain
cargo +$(RS_BUILD_TOOLCHAIN) install --locked wasm-bindgen-cli --version ^$(WASM_BINDGEN_VERSION)

.PHONY: install_node # Install last version of NodeJS via nvm
install_node:
curl -o nvm_install.sh https://raw.githubusercontent.com/nvm-sh/nvm/v0.39.3/install.sh
@echo "2ed5e94ba12434370f0358800deb69f514e8bce90f13beb0e1b241d42c6abafd nvm_install.sh" > nvm_checksum
@sha256sum -c nvm_checksum
@rm nvm_checksum
$(SHELL) nvm_install.sh
@rm nvm_install.sh
source ~/.bashrc
$(SHELL) -i -c 'nvm install $(NODE_VERSION)' || \
( echo "Unable to install node, unknown error." && exit 1 )

.PHONY: check_nvm_installed # Check if Node Version Manager is installed
check_nvm_installed:
@source ~/.nvm/nvm.sh && nvm --version > /dev/null 2>&1 || \
( echo "Unable to locate Node. Run 'make install_node'" && exit 1 )

.PHONY: fmt # Format rust code
fmt: install_rs_check_toolchain
cargo "$(CARGO_RS_CHECK_TOOLCHAIN)" fmt
Expand Down Expand Up @@ -85,7 +116,7 @@ test: install_rs_build_toolchain

.PHONY: test_serde
test_serde: install_rs_build_toolchain
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --release
RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --release \
--features=serde

.PHONY: test_nightly
Expand All @@ -105,8 +136,16 @@ test_no_std_nightly: install_rs_check_toolchain
--no-default-features \
--features=nightly,$(FFT128_FEATURE)

.PHONY: test_node_js
test_node_js: install_rs_build_toolchain install_build_wasm32_target install_wasm_bindgen_cli check_nvm_installed
source ~/.nvm/nvm.sh && \
nvm install $(NODE_VERSION) && \
nvm use $(NODE_VERSION) && \
RUSTFLAGS="" cargo $(CARGO_RS_BUILD_TOOLCHAIN) test --release \
--features=serde --target wasm32-unknown-unknown

.PHONY: test_all
test_all: test test_serde test_nightly test_no_std test_no_std_nightly
test_all: test test_serde test_nightly test_no_std test_no_std_nightly test_node_js

.PHONY: doc # Build rust doc
doc: install_rs_check_toolchain
Expand Down
3 changes: 3 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,9 @@ macro_rules! izip {
mod fft_simd;
mod nat;

#[cfg(feature = "std")]
pub(crate) mod time;

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
mod x86;

Expand Down
12 changes: 10 additions & 2 deletions src/ordered.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ fn measure_n_runs(
let (scratch, _) = stack.make_aligned_raw::<c64>(n, CACHELINE_ALIGN);
let [fwd, _] = get_fn_ptr(algo, n);

use std::time::Instant;
// For wasm we have a dedicated implementation going through js-sys
use crate::time::Instant;
let now = Instant::now();

for _ in 0..n_runs {
Expand Down Expand Up @@ -101,7 +102,13 @@ pub(crate) fn measure_fastest(
stack: PodStack,
) -> (FftAlgo, Duration) {
const N_ALGOS: usize = 8;
const MIN_DURATION: Duration = Duration::from_millis(1);
const MIN_DURATION: Duration = if cfg!(target_arch = "wasm32") {
// This is to account for the fact the js-sys based time measurement has a resolution of 1ms
// on chrome, this will slow down the fft benchmarking somewhat, but it's barely noticeable
Duration::from_millis(10)
} else {
Duration::from_millis(1)
};

assert!(n.is_power_of_two());

Expand Down Expand Up @@ -443,6 +450,7 @@ mod tests {
}
}

#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn test_fft() {
test_fft_simd(crate::fft_simd::Scalar);
Expand Down
11 changes: 11 additions & 0 deletions src/time/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
//! The standard API for Instant is not available in Wasm runtimes.
//! This module replaces the Instant type from std to a custom implementation.

#[cfg(target_arch = "wasm32")]
mod wasm;

#[cfg(target_arch = "wasm32")]
pub(crate) use wasm::Instant;

#[cfg(not(target_arch = "wasm32"))]
pub(crate) use std::time::Instant;
18 changes: 18 additions & 0 deletions src/time/wasm.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
pub(crate) struct Instant {
start: f64,
}

impl Instant {
/// This function only has a millisecond resolution on some platforms like the chrome browser
pub fn now() -> Self {
let now = js_sys::Date::new_0().get_time();
Self { start: now }
}

/// This function only has a millisecond resolution on some platforms like the chrome browser,
/// which means it can easily return 0 when called on quick code
pub fn elapsed(&self) -> core::time::Duration {
let now = js_sys::Date::new_0().get_time();
core::time::Duration::from_millis((now - self.start) as u64)
}
}
8 changes: 7 additions & 1 deletion src/unordered.rs
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,8 @@ fn measure_fastest(

let n_runs = n_runs.ceil() as u32;

use std::time::Instant;
// For wasm we have a dedicated implementation going through js-sys
use crate::time::Instant;
let now = Instant::now();
for _ in 0..n_runs {
fwd_depth(
Expand Down Expand Up @@ -1067,6 +1068,7 @@ mod tests {

extern crate alloc;

#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn test_fwd() {
for n in [128, 256, 512, 1024] {
Expand Down Expand Up @@ -1101,6 +1103,7 @@ mod tests {
}
}

#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn test_fwd_monomial() {
for n in [256, 512, 1024] {
Expand Down Expand Up @@ -1133,6 +1136,7 @@ mod tests {
}
}

#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn test_roundtrip() {
for n in [32, 64, 256, 512, 1024] {
Expand Down Expand Up @@ -1167,6 +1171,7 @@ mod tests {
}
}

#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn test_equivalency() {
use num_complex::Complex;
Expand Down Expand Up @@ -9401,6 +9406,7 @@ mod tests_serde {

extern crate alloc;

#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test::wasm_bindgen_test)]
#[test]
fn test_serde() {
for n in [64, 128, 256, 512, 1024] {
Expand Down
Loading