Skip to content

Commit

Permalink
Refactor feature flags (tracel-ai#1025)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanielsimard authored Dec 1, 2023
1 parent 1f18cf4 commit f6d14f1
Show file tree
Hide file tree
Showing 16 changed files with 76 additions and 73 deletions.
10 changes: 5 additions & 5 deletions backend-comparison/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ version = "0.11.0"
default = ["std"]
std = []
candle-cpu = ["burn/candle"]
candle-cuda = ["burn/candle-cuda"]
candle-accelerate = ["burn/candle-accelerate"]
candle-cuda = ["burn/candle", "burn/cuda"]
candle-accelerate = ["burn/candle", "burn/accelerate"]
ndarray = ["burn/ndarray"]
ndarray-blas-accelerate = ["burn/ndarray-blas-accelerate"]
ndarray-blas-netlib = ["burn/ndarray-blas-netlib"]
ndarray-blas-openblas = ["burn/ndarray-blas-openblas"]
ndarray-blas-accelerate = ["burn/ndarray", "burn/accelerate"]
ndarray-blas-netlib = ["burn/ndarray", "burn/blas-netlib"]
ndarray-blas-openblas = ["burn/ndarray", "burn/openblas"]
tch-cpu = ["burn/tch"]
tch-gpu = ["burn/tch"]
wgpu = ["burn/wgpu"]
Expand Down
2 changes: 2 additions & 0 deletions burn-candle/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ repository = "https://github.com/burn-rs/burn/tree/main/burn-candle"
version = "0.11.0"

[features]
default = ["std"]
std = []
cuda = ["candle-core/cuda"]
accelerate = ["candle-core/accelerate"]

Expand Down
55 changes: 29 additions & 26 deletions burn-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,19 @@ repository = "https://github.com/burn-rs/burn/tree/main/burn-core"
version = "0.11.0"

[features]
default = ["std", "dataset-minimal"]
default = [
"std",
"dataset",
"burn-ndarray?/default",
"burn-candle?/default",
"burn-wgpu?/default",
"burn-fusion?/default",
"burn-tch?/default",
"burn-dataset?/default",
"burn-common/default",
"burn-tensor/default",
]
std = [
"burn-common/std",
"burn-tensor/std",
"flate2",
"log",
"rand/std",
Expand All @@ -23,47 +32,41 @@ std = [
"serde_json/std",
"bincode/std",
"half/std",
"burn-ndarray?/std",
"burn-candle?/std",
"burn-wgpu?/std",
"burn-fusion?/std",
"burn-common/std",
"burn-tensor/std",
]
dataset = ["burn-dataset/default"]
dataset-minimal = ["burn-dataset"]
dataset-sqlite = ["burn-dataset/sqlite"]
dataset-sqlite-bundled = ["burn-dataset/sqlite-bundled"]
dataset = ["burn-dataset"]
sqlite = ["burn-dataset?/sqlite"]
sqlite-bundled = ["burn-dataset?/sqlite-bundled"]

wasm-sync = ["burn-tensor/wasm-sync", "burn-common/wasm-sync"]

# Backend
autodiff = ["burn-autodiff"]
fusion = ["burn-fusion", "burn-wgpu?/fusion"]

ndarray = ["__ndarray", "burn-ndarray/default"]
ndarray-no-std = ["__ndarray", "burn-ndarray"]
ndarray-blas-accelerate = [
"__ndarray",
"ndarray",
"burn-ndarray/blas-accelerate",
]
ndarray-blas-netlib = ["__ndarray", "ndarray", "burn-ndarray/blas-netlib"]
ndarray-blas-openblas = ["__ndarray", "ndarray", "burn-ndarray/blas-openblas"]
ndarray-blas-openblas-system = [
"__ndarray",
"ndarray",
"burn-ndarray/blas-openblas-system",
]
__ndarray = [] # Internal flag to know when one ndarray feature is enabled.
## Backend features
cuda = ["burn-candle?/cuda"]
accelerate = ["burn-candle?/accelerate", "burn-ndarray?/blas-accelerate"]
openblas = ["burn-ndarray?/blas-openblas"]
openblas-system = ["burn-ndarray?/blas-openblas-system"]
blas-netlib = ["burn-ndarray?/blas-netlib"]

ndarray = ["burn-ndarray"]
tch = ["burn-tch"]

candle = ["burn-candle"]
candle-cuda = ["candle", "burn-candle/cuda"]
candle-accelerate = ["candle", "burn-candle/accelerate"]
wgpu = ["burn-wgpu"]

# Serialization formats
experimental-named-tensor = ["burn-tensor/experimental-named-tensor"]

test-tch = ["tch"] # To use tch during testing, default uses ndarray.
test-wgpu = ["wgpu"] # To use wgpu during testing, default uses ndarray.

wgpu = ["burn-wgpu/default"]

[dependencies]

Expand Down
4 changes: 2 additions & 2 deletions burn-core/src/backend.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#[cfg(feature = "__ndarray")]
#[cfg(feature = "ndarray")]
pub use burn_ndarray as ndarray;

#[cfg(feature = "__ndarray")]
#[cfg(feature = "ndarray")]
pub use ndarray::NdArray;

#[cfg(feature = "autodiff")]
Expand Down
2 changes: 2 additions & 0 deletions burn-fusion/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ repository = "https://github.com/burn-rs/burn/tree/main/burn-fusion"
version = "0.11.0"

[features]
default = ["std"]
std = []

[dependencies]
burn-tensor = {path = "../burn-tensor", version = "0.11.0", default-features = false }
Expand Down
1 change: 1 addition & 0 deletions burn-tch/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ repository = "https://github.com/burn-rs/burn/tree/main/burn-tch"
version = "0.11.0"

[features]
default = []
doc = ["tch/doc-only"]

[dependencies]
Expand Down
3 changes: 2 additions & 1 deletion burn-wgpu/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ repository = "https://github.com/burn-rs/burn/tree/main/burn-wgpu"
version = "0.11.0"

[features]
default = ["autotune"]
default = ["autotune", "std"]
std = []
autotune = []
fusion = ["burn-fusion"]

Expand Down
38 changes: 17 additions & 21 deletions burn/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,46 +12,42 @@ version = "0.11.0"
rust-version = "1.71"

[features]
default = ["burn-core/default"]
default = ["burn-core/default", "burn-train?/default"]
std = ["burn-core/std"]

# Training with full features
train = ["burn-train/default", "autodiff", "dataset"]

# Useful when targeting WASM and not using WGPU.
wasm-sync = ["burn-core/wasm-sync"]

## Include nothing
train-minimal = ["burn-train"]
train = ["burn-train", "autodiff", "dataset"]

## Includes the Text UI (progress bars, metric plots)
train-tui = ["burn-train/tui"]
tui = ["burn-train?/tui"]

## Includes system info metrics (CPU/GPU usage, etc)
train-metrics = ["burn-train/metrics"]
metrics = ["burn-train?/metrics"]

# Useful when targeting WASM and not using WGPU.
wasm-sync = ["burn-core/wasm-sync"]

# Datasets
dataset = ["burn-core/dataset"]
dataset-minimal = ["burn-core/dataset-minimal"]
dataset-sqlite = ["burn-core/dataset-sqlite"]
dataset-sqlite-bundled = ["burn-core/dataset-sqlite-bundled"]

sqlite = ["burn-core/sqlite"]
sqlite-bundled = ["burn-core/sqlite-bundled"]

# Backends
autodiff = ["burn-core/autodiff"]
fusion = ["burn-core/fusion"]

ndarray = ["burn-core/ndarray"]
ndarray-no-std = ["burn-core/ndarray-no-std"]
ndarray-blas-accelerate = ["burn-core/ndarray-blas-accelerate"]
ndarray-blas-netlib = ["burn-core/ndarray-blas-netlib"]
ndarray-blas-openblas = ["burn-core/ndarray-blas-openblas"]
ndarray-blas-openblas-system = ["burn-core/ndarray-blas-openblas-system"]
## Backend features
cuda = ["burn-core/cuda"]
accelerate = ["burn-core/accelerate"]
openblas = ["burn-core/openblas"]
openblas-system = ["burn-core/openblas-system"]
blas-netlib = ["burn-core/blas-netlib"]

ndarray = ["burn-core/ndarray"]
wgpu = ["burn-core/wgpu"]
tch = ["burn-core/tch"]
candle = ["burn-core/candle"]
candle-cuda = ["burn-core/candle-cuda"]
candle-accelerate = ["burn-core/candle-accelerate"]

# Experimental
experimental-named-tensor = ["burn-core/experimental-named-tensor"]
Expand Down
4 changes: 2 additions & 2 deletions examples/custom-renderer/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ publish = false
version = "0.11.0"

[dependencies]
burn = {path = "../../burn", features=["autodiff", "wgpu", "train-minimal", "dataset"]}
guide = {path = "../guide", default-features=false, features=["train-minimal"]}
burn = {path = "../../burn", features=["autodiff", "wgpu", "train", "dataset"], default-features=false}
guide = {path = "../guide", default-features=false}

# Serialization
log = {workspace = true}
Expand Down
6 changes: 2 additions & 4 deletions examples/guide/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,10 @@ publish = false
version = "0.11.0"

[features]
default = ["train"]
train = ["burn/train"]
train-minimal = ["burn/train-minimal"]
default = ["burn/default"]

[dependencies]
burn = {path = "../../burn", features = ["wgpu"]}
burn = {path = "../../burn", features = ["wgpu", "train"]}

# Serialization
log = {workspace = true}
Expand Down
2 changes: 1 addition & 1 deletion examples/image-classification-web/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ half_precision = []

[dependencies]
burn = { path = "../../burn", version = "0.11.0", default-features = false, features = [
"ndarray-no-std",
"ndarray",
] }
burn-wgpu = { path = "../../burn-wgpu", version = "0.11.0", default-features = false }
burn-candle = { path = "../../burn-candle", version = "0.11.0", default-features = false }
Expand Down
2 changes: 1 addition & 1 deletion examples/mnist-inference-web/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ crate-type = ["cdylib"]
[features]
default = ["ndarray"]

ndarray = ["burn/ndarray-no-std"]
ndarray = ["burn/ndarray"]
wgpu = ["burn/wgpu"]

[dependencies]
Expand Down
8 changes: 4 additions & 4 deletions examples/mnist/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ publish = false
version = "0.11.0"

[features]
default = ["burn/dataset-sqlite-bundled"]
default = ["burn/dataset", "burn/sqlite-bundled"]
ndarray = ["burn/ndarray"]
ndarray-blas-accelerate = ["burn/ndarray-blas-accelerate"]
ndarray-blas-netlib = ["burn/ndarray-blas-netlib"]
ndarray-blas-openblas = ["burn/ndarray-blas-openblas"]
ndarray-blas-accelerate = ["burn/ndarray", "burn/accelerate"]
ndarray-blas-netlib = ["burn/ndarray", "burn/blas-netlib"]
ndarray-blas-openblas = ["burn/ndarray", "burn/openblas"]
tch-cpu = ["burn/tch"]
tch-gpu = ["burn/tch"]
wgpu = ["burn/wgpu"]
Expand Down
2 changes: 1 addition & 1 deletion examples/onnx-inference/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ default = ["embedded-model"]
embedded-model = []

[dependencies]
burn = { path = "../../burn", features = ["ndarray", "dataset-sqlite-bundled"] }
burn = { path = "../../burn", features = ["ndarray", "dataset", "sqlite-bundled"] }
serde = { workspace = true }

[build-dependencies]
Expand Down
8 changes: 4 additions & 4 deletions examples/text-classification/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@ publish = false
version = "0.11.0"

[features]
default = ["burn/dataset-sqlite-bundled"]
default = ["burn/dataset", "burn/sqlite-bundled"]
f16 = []
ndarray = ["burn/ndarray"]
ndarray-blas-accelerate = ["burn/ndarray-blas-accelerate"]
ndarray-blas-netlib = ["burn/ndarray-blas-netlib"]
ndarray-blas-openblas = ["burn/ndarray-blas-openblas"]
ndarray-blas-accelerate = ["burn/ndarray", "burn/accelerate"]
ndarray-blas-netlib = ["burn/ndarray", "burn/blas-netlib"]
ndarray-blas-openblas = ["burn/ndarray", "burn/openblas"]
tch-cpu = ["burn/tch"]
tch-gpu = ["burn/tch"]
wgpu = ["burn/wgpu"]
Expand Down
2 changes: 1 addition & 1 deletion examples/text-generation/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ publish = false
version = "0.11.0"

[features]
default = ["burn/dataset-sqlite-bundled"]
default = ["burn/dataset", "burn/sqlite-bundled"]
f16 = []

[dependencies]
Expand Down

0 comments on commit f6d14f1

Please sign in to comment.