Skip to content

Commit

Permalink
Don't enable training by default; add cli feature flag
Browse files Browse the repository at this point in the history
Partially implements tracel-ai#729:

- The burn crate no longer enables the training module unless the
train feature is enabled.
- Added a new train-cli feature to enable the metric plotting/CLI
rendering.
  • Loading branch information
dae committed Aug 31, 2023
1 parent 760c9e1 commit 4ac12b8
Show file tree
Hide file tree
Showing 12 changed files with 100 additions and 25 deletions.
28 changes: 19 additions & 9 deletions burn-train/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,31 @@ readme = "README.md"
repository = "https://github.com/burn-rs/burn/tree/main/burn-train"
version = "0.9.0"

[features]
cli = [
"indicatif",
"rgb",
"terminal_size",
"textplots",
"nvml-wrapper",
"sysinfo",
"systemstat"
]

[dependencies]
burn-core = {path = "../burn-core", version = "0.9.0" }

# Console
indicatif = "0.17.5"
log = {workspace = true}
log4rs = {workspace = true}

# Metrics
nvml-wrapper = "0.9.0"
rgb = "0.8.36"
terminal_size = "0.2.6"
textplots = "0.8.0"
sysinfo = "0.29.8"
systemstat = "0.2.3"
# CLI
indicatif = { version = "0.17.5", optional = true }
nvml-wrapper = { version = "0.9.0", optional = true }
rgb = { version = "0.8.36", optional = true }
terminal_size = { version = "0.2.6", optional = true }
textplots = { version = "0.8.0", optional = true }
sysinfo = { version = "0.29.8", optional = true }
systemstat = { version = "0.2.3", optional = true }

# Utilities
derive-new = {workspace = true}
Expand Down
17 changes: 11 additions & 6 deletions burn-train/src/learner/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ use super::log::update_log_file;
use super::Learner;
use crate::checkpoint::{AsyncCheckpointer, Checkpointer, FileCheckpointer};
use crate::logger::{FileMetricLogger, MetricLogger};
use crate::metric::dashboard::cli::CLIDashboardRenderer;
use crate::metric::dashboard::CLIDashboardRenderer;
use crate::metric::dashboard::{Dashboard, DashboardRenderer, MetricWrapper, Metrics};
use crate::metric::{Adaptor, Metric, Numeric};
use crate::metric::{Adaptor, Metric};
use crate::AsyncTrainerCallback;
use burn_core::lr_scheduler::LRScheduler;
use burn_core::module::ADModule;
Expand Down Expand Up @@ -141,12 +141,13 @@ where
///
/// # Notes
///
/// Only [numeric](Numeric) metric can be displayed on a plot.
/// Only [numeric](crate::metric::Numeric) metric can be displayed on a plot.
/// If the same metric is also registered for the [validation split](Self::metric_valid_plot),
/// the same graph will be used for both.
#[cfg(feature = "cli")]
pub fn metric_train_plot<M>(mut self, metric: M) -> Self
where
M: Metric + Numeric + 'static,
M: Metric + crate::metric::Numeric + 'static,
T: Adaptor<M::Input>,
{
self.metrics
Expand All @@ -159,10 +160,14 @@ where
///
/// # Notes
///
/// Only [numeric](Numeric) metric can be displayed on a plot.
/// Only [numeric](crate::metric::Numeric) metric can be displayed on a plot.
/// If the same metric is also registered for the [training split](Self::metric_train_plot),
/// the same graph will be used for both.
pub fn metric_valid_plot<M: Metric + Numeric + 'static>(mut self, metric: M) -> Self
#[cfg(feature = "cli")]
pub fn metric_valid_plot<M: Metric + crate::metric::Numeric + 'static>(
mut self,
metric: M,
) -> Self
where
V: Adaptor<M::Input>,
{
Expand Down
25 changes: 25 additions & 0 deletions burn-train/src/metric/dashboard/cli_stub.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
use crate::metric::dashboard::{DashboardMetricState, DashboardRenderer, TrainingProgress};

/// A simple renderer for when the cli feature is not enabled.
pub struct CLIDashboardRenderer;

impl CLIDashboardRenderer {
/// Create a new instance.
pub fn new() -> Self {
Self {}
}
}

impl DashboardRenderer for CLIDashboardRenderer {
fn update_train(&mut self, _state: DashboardMetricState) {}

fn update_valid(&mut self, _state: DashboardMetricState) {}

fn render_train(&mut self, item: TrainingProgress) {
dbg!(item);
}

fn render_valid(&mut self, item: TrainingProgress) {
dbg!(item);
}
}
10 changes: 9 additions & 1 deletion burn-train/src/metric/dashboard/mod.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
/// Command line interface module for the dashboard.
pub mod cli;
#[cfg(feature = "cli")]
mod cli;
#[cfg(not(feature = "cli"))]
mod cli_stub;

mod base;
mod plot;

pub use base::*;
pub use plot::*;

#[cfg(feature = "cli")]
pub use cli::CLIDashboardRenderer;
#[cfg(not(feature = "cli"))]
pub use cli_stub::CLIDashboardRenderer;
19 changes: 15 additions & 4 deletions burn-train/src/metric/dashboard/plot.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
use rgb::RGB8;
use terminal_size::{terminal_size, Height, Width};
use textplots::{Chart, ColorPlot, Shape};

/// Text plot.
pub struct TextPlot {
train: Vec<(f32, f32)>,
Expand Down Expand Up @@ -108,7 +104,12 @@ impl TextPlot {
/// # Returns
///
/// The rendered text plot.
#[cfg(feature = "cli")]
pub fn render(&self) -> String {
use rgb::RGB8;
use terminal_size::{terminal_size, Height, Width};
use textplots::{Chart, ColorPlot, Shape};

let train_color = RGB8::new(255, 140, 140);
let valid_color = RGB8::new(140, 140, 255);

Expand Down Expand Up @@ -146,4 +147,14 @@ impl TextPlot {
.linecolorplot(&Shape::Lines(&self.valid), valid_color)
.to_string()
}

/// Renders the text plot.
///
/// # Returns
///
/// The rendered text plot.
#[cfg(not(feature = "cli"))]
pub fn render(&self) -> String {
panic!("metrics feature not enabled on burn-train")
}
}
10 changes: 10 additions & 0 deletions burn-train/src/metric/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,30 @@ pub mod state;

mod acc;
mod base;
#[cfg(feature = "cli")]
mod cpu_temp;
#[cfg(feature = "cli")]
mod cpu_use;
#[cfg(feature = "cli")]
mod cuda;
#[cfg(feature = "cli")]
mod gpu_temp;
mod learning_rate;
mod loss;
#[cfg(feature = "cli")]
mod memory_use;

pub use acc::*;
pub use base::*;
#[cfg(feature = "cli")]
pub use cpu_temp::*;
#[cfg(feature = "cli")]
pub use cpu_use::*;
#[cfg(feature = "cli")]
pub use cuda::*;
#[cfg(feature = "cli")]
pub use gpu_temp::*;
pub use learning_rate::*;
pub use loss::*;
#[cfg(feature = "cli")]
pub use memory_use::*;
6 changes: 4 additions & 2 deletions burn/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,14 @@ repository = "https://github.com/burn-rs/burn"
version = "0.9.0"

[features]
default = ["std", "train"]
default = ["std"]
experimental-named-tensor = ["burn-core/experimental-named-tensor"]
std = [
"burn-core/std",
]
train = ["std", "burn-train"] # Training requires std
# Training requires std
train = ["std", "burn-train"]
train-cli = ["train", "burn-train/cli"]

[dependencies]

Expand Down
3 changes: 3 additions & 0 deletions examples/guide/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ name = "guide"
publish = false
version = "0.9.0"

[features]
default = ["burn/train-cli"]

[dependencies]
burn = {path = "../../burn"}
burn-autodiff = {path = "../../burn-autodiff"}
Expand Down
2 changes: 1 addition & 1 deletion examples/mnist/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ publish = false
version = "0.9.0"

[features]
default = []
default = ["burn/train-cli"]
ndarray = ["burn-ndarray"]
ndarray-blas-accelerate = ["burn-ndarray/blas-accelerate"]
ndarray-blas-netlib = ["burn-ndarray/blas-netlib"]
Expand Down
2 changes: 1 addition & 1 deletion examples/onnx-inference/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ publish = false
version = "0.9.0"

[features]
default = []
default = ["burn/train-cli"]

[dependencies]
burn = {path = "../../burn"}
Expand Down
1 change: 1 addition & 0 deletions examples/text-classification/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ publish = false
version = "0.9.0"

[features]
default = ["burn/train-cli"]
f16 = []
ndarray = ["burn-ndarray"]
ndarray-blas-accelerate = ["burn-ndarray/blas-accelerate"]
Expand Down
2 changes: 1 addition & 1 deletion examples/text-generation/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ publish = false
version = "0.9.0"

[features]
default = []
default = ["burn/train-cli"]
f16 = []

[dependencies]
Expand Down

0 comments on commit 4ac12b8

Please sign in to comment.