diff --git a/.github/workflows/precompile.yml b/.github/workflows/precompile.yml index 667aa21..ea027e0 100644 --- a/.github/workflows/precompile.yml +++ b/.github/workflows/precompile.yml @@ -1,6 +1,8 @@ name: precompile -on: push +on: + - push + - workflow_dispatch jobs: linux: @@ -52,7 +54,7 @@ jobs: MIX_ENV: "prod" strategy: matrix: - runner: ["macos-latest", "self-hosted"] + runner: ["macos-latest", "exgboost-m2-runner"] otp: ["25.0", "26.0"] elixir: ["1.14.5"] steps: diff --git a/.gitignore b/.gitignore index f214385..15fd0cd 100644 --- a/.gitignore +++ b/.gitignore @@ -11,4 +11,5 @@ erl_crash.dump .elixir_ls/ .tool-versions .vscode/ -checksum.exs \ No newline at end of file +checksum.exs +.DS_Store \ No newline at end of file diff --git a/Makefile b/Makefile index 4520e18..41c14c3 100644 --- a/Makefile +++ b/Makefile @@ -63,6 +63,7 @@ $(XGBOOST_LIB_DIR_FLAG): git fetch --depth 1 --recurse-submodules origin $(XGBOOST_GIT_REV) && \ git checkout FETCH_HEAD && \ git submodule update --init --recursive && \ + sed 's|learner_parameters\["generic_param"\] = ToJson(ctx_);|&\nlearner_parameters\["default_metric"\] = String(obj_->DefaultEvalMetric());|' src/learner.cc > src/learner.cc.tmp && mv src/learner.cc.tmp src/learner.cc && \ cmake -DCMAKE_INSTALL_PREFIX=$(XGBOOST_LIB_DIR) -B build . $(CMAKE_FLAGS) && \ make -C build -j1 install touch $(XGBOOST_LIB_DIR_FLAG) diff --git a/README.md b/README.md index 8da4acd..0a8b04c 100644 --- a/README.md +++ b/README.md @@ -22,7 +22,7 @@ billions of examples. 
```elixir def deps do [ - {:exgboost, "~> 0.3"} + {:exgboost, "~> 0.5"} ] end ``` diff --git a/lib/exgboost.ex b/lib/exgboost.ex index 67f9636..fffbd76 100644 --- a/lib/exgboost.ex +++ b/lib/exgboost.ex @@ -16,7 +16,7 @@ defmodule EXGBoost do ```elixir def deps do [ - {:exgboost, "~> 0.4"} + {:exgboost, "~> 0.5"} ] end ``` @@ -92,7 +92,7 @@ defmodule EXGBoost do preds = EXGBoost.train(X, y) |> EXGBoost.predict(X) ``` - ## Serliaztion + ## Serialization A Booster can be serialized to a file using `EXGBoost.write_*` and loaded from a file using `EXGBoost.read_*`. The file format can be specified using the `:format` option @@ -113,6 +113,34 @@ defmodule EXGBoost do - `config` - Save the configuration only. - `weights` - Save the model parameters only. Use this when you want to save the model to a format that can be ingested by other XGBoost APIs. - `model` - Save both the model parameters and the configuration. + + ## Plotting + + `EXGBoost.plot_tree/2` is the primary entry point for plotting a tree from a trained model. + It accepts an `EXGBoost.Booster` struct (which is the output of `EXGBoost.train/2`). + `EXGBoost.plot_tree/2` returns a VegaLite spec that can be rendered in a notebook or saved to a file. + `EXGBoost.plot_tree/2` also accepts a keyword list of options that can be used to configure the plotting process. + + See `EXGBoost.Plotting` for more detail on plotting. + + You can see available styles by running `EXGBoost.Plotting.get_styles()` or refer to the `EXGBoost.Plotting.Styles` + documentation for a gallery of the styles. + + ## Kino & Livebook Integration + + `EXGBoost` integrates with [Kino](https://hexdocs.pm/kino/Kino.html) and [Livebook](https://livebook.dev/) + to provide a rich interactive experience for data scientists. + + EXGBoost implements the `Kino.Render` protocol for `EXGBoost.Booster` structs. This allows you to render + a Booster in a Livebook notebook. 
Under the hood, `EXGBoost` uses [Vega-Lite](https://vega.github.io/vega-lite/) + and [Kino Vega-Lite](https://hexdocs.pm/kino_vega_lite/Kino.VegaLite.html) to render the Booster. + + See the [`Plotting in EXGBoost`](notebooks/plotting.livemd) Notebook for an example of how to use `EXGBoost` with `Kino` and `Livebook`. + + ## Examples + + See the example Notebooks in the left sidebar (under the `Pages` tab) for more examples and tutorials + on how to use EXGBoost. """ alias EXGBoost.ArrayInterface @@ -121,6 +149,7 @@ defmodule EXGBoost do alias EXGBoost.DMatrix alias EXGBoost.ProxyDMatrix alias EXGBoost.Training + alias EXGBoost.Plotting @doc """ Check the build information of the xgboost library. @@ -128,6 +157,7 @@ defmodule EXGBoost do Returns a map containing information about the build. """ @spec xgboost_build_info() :: map() + @doc type: :system def xgboost_build_info, do: EXGBoost.NIF.xgboost_build_info() |> Internal.unwrap!() |> Jason.decode!() @@ -137,6 +167,7 @@ defmodule EXGBoost do Returns a 3-tuple in the form of `{major, minor, patch}`. """ @spec xgboost_version() :: {integer(), integer(), integer()} | {:error, String.t()} + @doc type: :system def xgboost_version, do: EXGBoost.NIF.xgboost_version() |> Internal.unwrap!() @doc """ @@ -147,6 +178,7 @@ defmodule EXGBoost do for the full list of parameters supported in the global configuration. """ @spec set_config(map()) :: :ok | {:error, String.t()} + @doc type: :system def set_config(%{} = config) do config = EXGBoost.Parameters.validate_global!(config) EXGBoost.NIF.set_global_config(Jason.encode!(config)) |> Internal.unwrap!() @@ -160,6 +192,7 @@ defmodule EXGBoost do for the full list of parameters supported in the global configuration. """ @spec get_config() :: map() + @doc type: :system def get_config do EXGBoost.NIF.get_global_config() |> Internal.unwrap!() |> Jason.decode!() end @@ -208,10 +241,11 @@ defmodule EXGBoost do * `opts` - Refer to `EXGBoost.Parameters` for the full list of options. 
""" @spec train(Nx.Tensor.t(), Nx.Tensor.t(), Keyword.t()) :: EXGBoost.Booster.t() + @doc type: :train_pred def train(x, y, opts \\ []) do x = Nx.concatenate(x) y = Nx.concatenate(y) - {dmat_opts, opts} = Keyword.split(opts, Internal.dmatrix_feature_opts()) + dmat_opts = Keyword.take(opts, Internal.dmatrix_feature_opts()) dmat = DMatrix.from_tensor(x, y, Keyword.put_new(dmat_opts, :format, :dense)) Training.train(dmat, opts) end @@ -272,6 +306,7 @@ defmodule EXGBoost do Returns an Nx.Tensor containing the predictions. """ + @doc type: :train_pred def predict(%Booster{} = bst, x, opts \\ []) do x = Nx.concatenate(x) {dmat_opts, opts} = Keyword.split(opts, Internal.dmatrix_feature_opts()) @@ -302,6 +337,7 @@ defmodule EXGBoost do Returns an Nx.Tensor containing the predictions. """ + @doc type: :train_pred def inplace_predict(%Booster{} = boostr, data, opts \\ []) do opts = Keyword.validate!(opts, @@ -428,6 +464,7 @@ defmodule EXGBoost do ## Options #{NimbleOptions.docs(@write_schema)} """ + @doc type: :serialization @spec write_model(Booster.t(), String.t()) :: :ok | {:error, String.t()} def write_model(%Booster{} = booster, path, opts \\ []) do opts = NimbleOptions.validate!(opts, @write_schema) @@ -437,6 +474,7 @@ defmodule EXGBoost do @doc """ Read a model from a file and return the Booster. """ + @doc type: :serialization @spec read_model(String.t()) :: EXGBoost.Booster.t() def read_model(path) do EXGBoost.Booster.load(path, deserialize: :model) @@ -449,6 +487,7 @@ defmodule EXGBoost do #{NimbleOptions.docs(@dump_schema)} """ @spec dump_model(Booster.t()) :: binary() + @doc type: :serialization def dump_model(%Booster{} = booster, opts \\ []) do opts = NimbleOptions.validate!(opts, @dump_schema) EXGBoost.Booster.save(booster, opts ++ [serialize: :model, to: :buffer]) @@ -458,6 +497,7 @@ defmodule EXGBoost do Read a model from a buffer and return the Booster. 
""" @spec load_model(binary()) :: EXGBoost.Booster.t() + @doc type: :serialization def load_model(buffer) do EXGBoost.Booster.load(buffer, deserialize: :model, from: :buffer) end @@ -469,6 +509,7 @@ defmodule EXGBoost do #{NimbleOptions.docs(@write_schema)} """ @spec write_config(Booster.t(), String.t()) :: :ok | {:error, String.t()} + @doc type: :serialization def write_config(%Booster{} = booster, path, opts \\ []) do opts = NimbleOptions.validate!(opts, @write_schema) EXGBoost.Booster.save(booster, opts ++ [path: path, serialize: :config]) @@ -481,6 +522,7 @@ defmodule EXGBoost do #{NimbleOptions.docs(@dump_schema)} """ @spec dump_config(Booster.t()) :: binary() + @doc type: :serialization def dump_config(%Booster{} = booster, opts \\ []) do opts = NimbleOptions.validate!(opts, @dump_schema) EXGBoost.Booster.save(booster, opts ++ [serialize: :config, to: :buffer]) @@ -493,6 +535,7 @@ defmodule EXGBoost do #{NimbleOptions.docs(@load_schema)} """ @spec read_config(String.t()) :: EXGBoost.Booster.t() + @doc type: :serialization def read_config(path, opts \\ []) do opts = NimbleOptions.validate!(opts, @load_schema) EXGBoost.Booster.load(path, opts ++ [deserialize: :config]) @@ -505,6 +548,7 @@ defmodule EXGBoost do #{NimbleOptions.docs(@load_schema)} """ @spec load_config(binary()) :: EXGBoost.Booster.t() + @doc type: :serialization def load_config(buffer, opts \\ []) do opts = NimbleOptions.validate!(opts, @load_schema) EXGBoost.Booster.load(buffer, opts ++ [deserialize: :config, from: :buffer]) @@ -517,6 +561,7 @@ defmodule EXGBoost do #{NimbleOptions.docs(@write_schema)} """ @spec write_weights(Booster.t(), String.t()) :: :ok | {:error, String.t()} + @doc type: :serialization def write_weights(%Booster{} = booster, path, opts \\ []) do opts = NimbleOptions.validate!(opts, @write_schema) EXGBoost.Booster.save(booster, opts ++ [path: path, serialize: :weights]) @@ -529,6 +574,7 @@ defmodule EXGBoost do #{NimbleOptions.docs(@dump_schema)} """ @spec 
dump_weights(Booster.t()) :: binary() + @doc type: :serialization def dump_weights(%Booster{} = booster, opts \\ []) do opts = NimbleOptions.validate!(opts, @dump_schema) EXGBoost.Booster.save(booster, opts ++ [serialize: :weights, to: :buffer]) @@ -538,6 +584,7 @@ defmodule EXGBoost do Read a model's trained parameters from a file and return the Booster. """ @spec read_weights(String.t()) :: EXGBoost.Booster.t() + @doc type: :serialization def read_weights(path) do EXGBoost.Booster.load(path, deserialize: :weights) end @@ -546,7 +593,30 @@ defmodule EXGBoost do Read a model's trained parameters from a buffer and return the Booster. """ @spec load_weights(binary()) :: EXGBoost.Booster.t() + @doc type: :serialization def load_weights(buffer) do EXGBoost.Booster.load(buffer, deserialize: :weights, from: :buffer) end + + @doc """ + Plot a tree from a Booster model and optionally save it to a file. + + ## Options + * `:format` - the format to export the graphic as, must be one of: `:json`, `:html`, `:png`, `:svg`, `:pdf`. By default the format is inferred from the file extension. + * `:local_npm_prefix` - a relative path pointing to a local npm project directory where the necessary npm packages are installed. For instance, in Phoenix projects you may want to pass `local_npm_prefix: "assets"`. By default the npm packages are searched for in the current directory and globally. + * `:path` - the path to save the graphic to. If not provided, the graphic is returned as a VegaLite spec. + * All remaining options are passed directly to `EXGBoost.Plotting.plot/2`. See `EXGBoost.Plotting` for more information. 
+ """ + @doc type: :plotting + def plot_tree(booster, opts \\ []) do + {path, opts} = Keyword.pop(opts, :path) + {save_opts, opts} = Keyword.split(opts, [:format, :local_npm_prefix]) + vega = Plotting.plot(booster, opts) + + if path != nil do + VegaLite.Export.save!(vega, path, save_opts) + else + vega + end + end end diff --git a/lib/exgboost/booster.ex b/lib/exgboost/booster.ex index f069cdc..f60b8ed 100644 --- a/lib/exgboost/booster.ex +++ b/lib/exgboost/booster.ex @@ -163,9 +163,15 @@ defmodule EXGBoost.Booster do def booster(dmats, opts \\ []) def booster(dmats, opts) when is_list(dmats) do + {str_opts, opts} = Keyword.split(opts, Internal.dmatrix_str_feature_opts()) opts = EXGBoost.Parameters.validate!(opts) refs = Enum.map(dmats, & &1.ref) booster_ref = EXGBoost.NIF.booster_create(refs) |> Internal.unwrap!() + + Enum.each(str_opts, fn {key, value} -> + EXGBoost.NIF.booster_set_str_feature_info(booster_ref, Atom.to_string(key), value) + end) + set_params(%__MODULE__{ref: booster_ref}, opts) end @@ -174,9 +180,15 @@ defmodule EXGBoost.Booster do end def booster(%__MODULE__{} = bst, opts) do + {str_opts, opts} = Keyword.split(opts, Internal.dmatrix_str_feature_opts()) opts = EXGBoost.Parameters.validate!(opts) boostr_bytes = EXGBoost.NIF.booster_serialize_to_buffer(bst.ref) |> Internal.unwrap!() booster_ref = EXGBoost.NIF.booster_deserialize_from_buffer(boostr_bytes) |> Internal.unwrap!() + + Enum.each(str_opts, fn {key, value} -> + EXGBoost.NIF.booster_set_str_feature_info(booster_ref, Atom.to_string(key), value) + end) + set_params(%__MODULE__{ref: booster_ref}, opts) end diff --git a/lib/exgboost/nif.ex b/lib/exgboost/nif.ex index ba09dfc..bbed9e0 100644 --- a/lib/exgboost/nif.ex +++ b/lib/exgboost/nif.ex @@ -100,7 +100,6 @@ defmodule EXGBoost.NIF do def dmatrix_create_from_file(_file_uri, _silent), do: :erlang.nif_error(:not_implemented) - @since "0.4.0" def dmatrix_create_from_uri(_config), do: :erlang.nif_error(:not_implemented) @spec 
dmatrix_create_from_mat(binary, integer(), integer(), float()) :: diff --git a/lib/exgboost/plotting.ex b/lib/exgboost/plotting.ex new file mode 100644 index 0000000..50719fb --- /dev/null +++ b/lib/exgboost/plotting.ex @@ -0,0 +1,1774 @@ +defmodule EXGBoost.Plotting do + use EXGBoost.Plotting.Style + + @doc """ + A light theme based on the [Solarized](https://ethanschoonover.com/solarized/) color palette + """ + style :solarized_light do + [ + # base3 + background: "#fdf6e3", + leaves: [ + # base01 + text: [fill: "#586e75", fontSize: 12, fontStyle: "normal", fontWeight: "normal"], + # base2, base1 + rect: [fill: "#eee8d5", stroke: "#93a1a1", strokeWidth: 1] + ], + splits: [ + # base01 + text: [fill: "#586e75", fontSize: 12, fontStyle: "normal", fontWeight: "bold"], + # base1, base00 + rect: [fill: "#93a1a1", stroke: "#657b83", strokeWidth: 1], + # base00 + children: [fill: "#657b83", stroke: "#657b83", strokeWidth: 1] + ], + yes: [ + # green + text: [fill: "#859900"], + # base00 + path: [stroke: "#657b83", strokeWidth: 1] + ], + no: [ + # red + text: [fill: "#dc322f"], + # base00 + path: [stroke: "#657b83", strokeWidth: 1] + ] + ] + end + + @doc """ + A dark theme based on the [Solarized](https://ethanschoonover.com/solarized/) color palette + """ + style :solarized_dark do + # base03 + [ + background: "#002b36", + leaves: [ + # base0 + text: [fill: "#839496", fontSize: 12, fontStyle: "normal", fontWeight: "normal"], + # base02, base01 + rect: [fill: "#073642", stroke: "#586e75", strokeWidth: 1] + ], + splits: [ + # base0 + text: [fill: "#839496", fontSize: 12, fontStyle: "normal", fontWeight: "bold"], + # base01, base00 + rect: [fill: "#586e75", stroke: "#657b83", strokeWidth: 1], + # base00 + children: [fill: "#657b83", stroke: "#657b83", strokeWidth: 1] + ], + yes: [ + # green + text: [fill: "#859900"], + # base00 + path: [stroke: "#657b83", strokeWidth: 1] + ], + no: [ + # red + text: [fill: "#dc322f"], + # base00 + path: [stroke: "#657b83", strokeWidth: 1] + 
] + ] + end + + @doc """ + A light and playful theme + """ + style :playful_light do + [ + background: "#f0f0f0", + padding: 10, + leaves: [ + text: [fill: "#000", font_size: 12, font_style: "italic", font_weight: "bold"], + rect: [fill: "#e91e63", stroke: "#000", stroke_width: 1, radius: 5] + ], + splits: [ + text: [fill: "#000", font_size: 12, font_style: "normal", font_weight: "bold"], + children: [ + fill: "#000", + font_size: 12, + font_style: "normal", + font_weight: "bold" + ], + rect: [fill: "#8bc34a", stroke: "#000", stroke_width: 1, radius: 10] + ], + yes: [ + path: [stroke: "#4caf50", stroke_width: 2] + ], + no: [ + path: [stroke: "#f44336", stroke_width: 2] + ] + ] + end + + @doc """ + A dark and playful theme + """ + style :playful_dark do + [ + background: "#333", + padding: 10, + leaves: [ + text: [fill: "#fff", font_size: 12, font_style: "italic", font_weight: "bold"], + rect: [fill: "#e91e63", stroke: "#fff", stroke_width: 1, radius: 5] + ], + splits: [ + text: [fill: "#fff", font_size: 12, font_style: "normal", font_weight: "bold"], + rect: [fill: "#8bc34a", stroke: "#fff", stroke_width: 1, radius: 10] + ], + yes: [ + text: [fill: "#4caf50"], + path: [stroke: "#4caf50", stroke_width: 2] + ], + no: [ + text: [fill: "#f44336"], + path: [stroke: "#f44336", stroke_width: 2] + ] + ] + end + + @doc """ + A dark theme + """ + style :dark do + [ + background: "#333", + padding: 10, + leaves: [ + text: [fill: "#fff", fontSize: 12, fontStyle: "normal", fontWeight: "normal"], + rect: [fill: "#666", stroke: "#fff", strokeWidth: 1] + ], + splits: [ + text: [fill: "#fff", fontSize: 12, fontStyle: "normal", fontWeight: "bold"], + rect: [fill: "#444", stroke: "#fff", strokeWidth: 1], + children: [fill: "#fff", stroke: "#fff", strokeWidth: 1] + ] + ] + end + + @doc """ + A high contrast theme + """ + style :high_contrast do + [ + background: "#000", + padding: 10, + leaves: [ + text: [fill: "#fff", fontSize: 12, fontStyle: "normal", fontWeight: "normal"], + rect: 
[fill: "#333", stroke: "#fff", strokeWidth: 1] + ], + splits: [ + text: [fill: "#fff", fontSize: 12, fontStyle: "normal", fontWeight: "bold"], + rect: [fill: "#666", stroke: "#fff", strokeWidth: 1] + ] + ] + end + + @doc """ + A light theme + """ + style :light do + [ + background: "#f0f0f0", + padding: 10, + leaves: [ + text: [fill: "#000", fontSize: 12, fontStyle: "normal", fontWeight: "normal"], + rect: [fill: "#ddd", stroke: "#000", strokeWidth: 1] + ], + splits: [ + text: [fill: "#000", fontSize: 12, fontStyle: "normal", fontWeight: "bold"], + rect: [fill: "#bbb", stroke: "#000", strokeWidth: 1] + ] + ] + end + + @doc """ + A theme based on the [Monokai](https://monokai.pro/) color palette + """ + style :monokai do + [ + background: "#272822", + leaves: [ + text: [fill: "#f8f8f2", fontSize: 12, fontStyle: "normal", fontWeight: "normal"], + rect: [fill: "#3e3d32", stroke: "#66d9ef", strokeWidth: 1] + ], + splits: [ + text: [fill: "#f8f8f2", fontSize: 12, fontStyle: "normal", fontWeight: "bold"], + rect: [fill: "#66d9ef", stroke: "#f8f8f2", strokeWidth: 1], + children: [fill: "#f8f8f2", stroke: "#f8f8f2", strokeWidth: 1] + ], + yes: [ + text: [fill: "#a6e22e"], + path: [stroke: "#f8f8f2", strokeWidth: 1] + ], + no: [ + text: [fill: "#f92672"], + path: [stroke: "#f8f8f2", strokeWidth: 1] + ] + ] + end + + @doc """ + A theme based on the [Dracula](https://draculatheme.com/) color palette + """ + style :dracula do + [ + background: "#282a36", + leaves: [ + text: [fill: "#f8f8f2", fontSize: 12, fontStyle: "normal", fontWeight: "normal"], + rect: [fill: "#44475a", stroke: "#ff79c6", strokeWidth: 1] + ], + splits: [ + text: [fill: "#f8f8f2", fontSize: 12, fontStyle: "normal", fontWeight: "bold"], + rect: [fill: "#ff79c6", stroke: "#f8f8f2", strokeWidth: 1], + children: [fill: "#f8f8f2", stroke: "#f8f8f2", strokeWidth: 1] + ], + yes: [ + text: [fill: "#50fa7b"], + path: [stroke: "#f8f8f2", strokeWidth: 1] + ], + no: [ + text: [fill: "#ff5555"], + path: [stroke: 
"#f8f8f2", strokeWidth: 1] + ] + ] + end + + @doc """ + A theme based on the [Nord](https://www.nordtheme.com/) color palette + """ + style :nord do + [ + background: "#2e3440", + leaves: [ + text: [fill: "#d8dee9", fontSize: 12, fontStyle: "normal", fontWeight: "normal"], + rect: [fill: "#3b4252", stroke: "#88c0d0", strokeWidth: 1] + ], + splits: [ + text: [fill: "#d8dee9", fontSize: 12, fontStyle: "normal", fontWeight: "bold"], + rect: [fill: "#88c0d0", stroke: "#d8dee9", strokeWidth: 1], + children: [fill: "#d8dee9", stroke: "#d8dee9", strokeWidth: 1] + ], + yes: [ + text: [fill: "#a3be8c"], + path: [stroke: "#d8dee9", strokeWidth: 1] + ], + no: [ + text: [fill: "#bf616a"], + path: [stroke: "#d8dee9", strokeWidth: 1] + ] + ] + end + + @doc """ + A theme based on the [Material](https://material.io/design/color/the-color-system.html#tools-for-picking-colors) color palette + """ + style :material do + [ + background: "#263238", + leaves: [ + text: [fill: "#eceff1", fontSize: 12, fontStyle: "normal", fontWeight: "normal"], + rect: [fill: "#37474f", stroke: "#80cbc4", strokeWidth: 1] + ], + splits: [ + text: [fill: "#eceff1", fontSize: 12, fontStyle: "normal", fontWeight: "bold"], + rect: [fill: "#80cbc4", stroke: "#eceff1", strokeWidth: 1], + children: [fill: "#eceff1", stroke: "#eceff1", strokeWidth: 1] + ], + yes: [ + text: [fill: "#c5e1a5"], + path: [stroke: "#eceff1", strokeWidth: 1] + ], + no: [ + text: [fill: "#ef9a9a"], + path: [stroke: "#eceff1", strokeWidth: 1] + ] + ] + end + + @doc """ + A theme based on the One Dark color palette + """ + style :one_dark do + [ + background: "#282c34", + leaves: [ + text: [fill: "#abb2bf", fontSize: 12, fontStyle: "normal", fontWeight: "normal"], + rect: [fill: "#3b4048", stroke: "#98c379", strokeWidth: 1] + ], + splits: [ + text: [fill: "#abb2bf", fontSize: 12, fontStyle: "normal", fontWeight: "bold"], + rect: [fill: "#98c379", stroke: "#abb2bf", strokeWidth: 1], + children: [fill: "#abb2bf", stroke: "#abb2bf", 
strokeWidth: 1] + ], + yes: [ + text: [fill: "#98c379"], + path: [stroke: "#abb2bf", strokeWidth: 1] + ], + no: [ + text: [fill: "#e06c75"], + path: [stroke: "#abb2bf", strokeWidth: 1] + ] + ] + end + + @doc """ + A theme based on the [Gruvbox](https://github.com/morhetz/gruvbox) color palette + """ + style :gruvbox do + [ + background: "#282828", + leaves: [ + text: [fill: "#ebdbb2", fontSize: 12, fontStyle: "normal", fontWeight: "normal"], + rect: [fill: "#3c3836", stroke: "#b8bb26", strokeWidth: 1] + ], + splits: [ + text: [fill: "#ebdbb2", fontSize: 12, fontStyle: "normal", fontWeight: "bold"], + rect: [fill: "#b8bb26", stroke: "#ebdbb2", strokeWidth: 1], + children: [fill: "#ebdbb2", stroke: "#ebdbb2", strokeWidth: 1] + ], + yes: [ + text: [fill: "#b8bb26"], + path: [stroke: "#ebdbb2", strokeWidth: 1] + ], + no: [ + text: [fill: "#fb4934"], + path: [stroke: "#ebdbb2", strokeWidth: 1] + ] + ] + end + + @doc """ + A dark theme based on the [Horizon](https://www.horizon.io/) color palette + """ + style :horizon_dark do + [ + background: "#1C1E26", + leaves: [ + text: [fill: "#E3E6EE", fontSize: 12, fontStyle: "normal", fontWeight: "normal"], + rect: [fill: "#232530", stroke: "#F43E5C", strokeWidth: 1] + ], + splits: [ + text: [fill: "#E3E6EE", fontSize: 12, fontStyle: "normal", fontWeight: "bold"], + rect: [fill: "#F43E5C", stroke: "#E3E6EE", strokeWidth: 1], + children: [fill: "#E3E6EE", stroke: "#E3E6EE", strokeWidth: 1] + ], + yes: [ + text: [fill: "#48B685"], + path: [stroke: "#E3E6EE", strokeWidth: 1] + ], + no: [ + text: [fill: "#F43E5C"], + path: [stroke: "#E3E6EE", strokeWidth: 1] + ] + ] + end + + @doc """ + A light theme based on the [Horizon](https://www.horizon.io/) color palette + """ + style :horizon_light do + [ + background: "#FDF0ED", + leaves: [ + text: [fill: "#1A2026", fontSize: 12, fontStyle: "normal", fontWeight: "normal"], + rect: [fill: "#F7E3D3", stroke: "#F43E5C", strokeWidth: 1] + ], + splits: [ + text: [fill: "#1A2026", fontSize: 12, 
fontStyle: "normal", fontWeight: "bold"], + rect: [fill: "#F43E5C", stroke: "#1A2026", strokeWidth: 1], + children: [fill: "#1A2026", stroke: "#1A2026", strokeWidth: 1] + ], + yes: [ + text: [fill: "#48B685"], + path: [stroke: "#1A2026", strokeWidth: 1] + ], + no: [ + text: [fill: "#F43E5C"], + path: [stroke: "#1A2026", strokeWidth: 1] + ] + ] + end + + HTTPoison.start() + + @schema HTTPoison.get!("https://vega.github.io/schema/vega/v5.json").body + |> Jason.decode!() + |> ExJsonSchema.Schema.resolve() + + @mark_text_doc "Accepts a keyword list of Vega `text` Mark properties. Reference [here](https://vega.github.io/vega/docs/marks/text/) for more details. Accepts either a string (expected to be valid Vega property names) or Elixir-styled atom. Note that keys are snake-cased instead of camel-case (e.g. Vega `fontSize` becomes `font_size`)" + @mark_rect_doc "Accepts a keyword list of Vega `rect` Mark properties. Reference [here](https://vega.github.io/vega/docs/marks/rect/) for more details. Accepts either a string (expected to be valid Vega property names) or Elixir-styled atom. Note that keys are snake-cased instead of camel-case (e.g. Vega `fontSize` becomes `font_size`)" + @mark_path_doc "Accepts a keyword list of Vega `path` Mark properties. Reference [here](https://vega.github.io/vega/docs/marks/path/) for more details. Accepts either a string (expected to be valid Vega property names) or Elixir-styled atom. Note that keys are snake-cased instead of camel-case (e.g. 
Vega `fontSize` becomes `font_size`)" + + @mark_opts [ + type: :keyword_list, + keys: [ + *: [type: {:or, [:string, :atom, :integer, :boolean, nil, {:list, :any}]}] + ] + ] + + @default_leaf_text [ + align: :center, + baseline: :middle, + font_size: 13, + font: "Calibri" + ] + + @default_leaf_rect [ + corner_radius: 2, + opacity: 1 + ] + + @default_split_text [ + align: :center, + baseline: :middle, + font_size: 13, + font: "Calibri" + ] + + @default_split_rect [ + corner_radius: 2, + opacity: 1 + ] + + @default_split_children [ + align: :right, + baseline: :middle, + font: "Calibri", + font_size: 13 + ] + + @default_yes_path [] + + @default_yes_text [ + align: :center, + baseline: :middle, + font_size: 13, + font: "Calibri", + text: "yes" + ] + + @default_no_path [] + + @default_no_text [ + align: :center, + baseline: :middle, + font_size: 13, + font: "Calibri", + text: "no" + ] + + @plotting_params [ + style: [ + doc: + "The style to use for the visualization. Refer to `EXGBoost.Plotting.Styles` for a list of available styles.", + default: Application.compile_env(:exgboost, [:plotting, :style], :dracula), + type: {:or, [{:in, Keyword.keys(@styles)}, {:in, [nil, false]}, :keyword_list]} + ], + rankdir: [ + doc: "Determines the direction of the graph.", + type: {:in, [:tb, :lr, :bt, :rl]}, + default: Application.compile_env(:exgboost, [:plotting, :rankdir], :tb) + ], + autosize: [ + doc: + "Determines if the visualization should automatically resize when the window size changes", + type: {:in, ["fit", "pad", "fit-x", "fit-y", "none"]}, + default: Application.compile_env(:exgboost, [:plotting, :autosize], "fit") + ], + background: [ + doc: + "The background color of the visualization. Accepts a valid CSS color string. 
For example: `#f304d3`, `#ccc`, `rgb(253, 12, 134)`, `steelblue.`", + type: :string, + default: Application.compile_env(:exgboost, [:plotting, :background], "#f5f5f5") + ], + height: [ + doc: "Height of the plot in pixels", + type: :pos_integer, + default: Application.compile_env(:exgboost, [:plotting, :height], 400) + ], + width: [ + doc: "Width of the plot in pixels", + type: :pos_integer, + default: Application.compile_env(:exgboost, [:plotting, :width], 600) + ], + padding: [ + doc: + "The padding in pixels to add around the visualization. If a number, specifies padding for all sides. If an object, the value should have the format `[left: value, right: value, top: value, bottom: value]`", + type: + {:or, + [ + :pos_integer, + keyword_list: [ + left: [type: :pos_integer], + right: [type: :pos_integer], + top: [type: :pos_integer], + bottom: [type: :pos_integer] + ] + ]}, + default: Application.compile_env(:exgboost, [:plotting, :padding], 30) + ], + leaves: [ + doc: "Specifies characteristics of leaf nodes", + type: :keyword_list, + keys: [ + text: + @mark_opts ++ + [ + doc: @mark_text_doc, + default: + deep_merge_kw( + @default_leaf_text, + Application.compile_env( + :exgboost, + [:plotting, :leaves, :text], + [] + ) + ) + ], + rect: + @mark_opts ++ + [ + doc: @mark_rect_doc, + default: + deep_merge_kw( + @default_leaf_rect, + Application.compile_env( + :exgboost, + [:plotting, :leaves, :rect], + [] + ) + ) + ] + ], + default: [ + text: + deep_merge_kw( + @default_leaf_text, + Application.compile_env(:exgboost, [:plotting, :leaves, :text], []) + ), + rect: + deep_merge_kw( + @default_leaf_rect, + Application.compile_env(:exgboost, [:plotting, :leaves, :rect], []) + ) + ] + ], + splits: [ + doc: "Specifies characteristics of split nodes", + type: :keyword_list, + keys: [ + text: + @mark_opts ++ + [ + doc: @mark_text_doc, + default: + deep_merge_kw( + @default_split_text, + Application.compile_env( + :exgboost, + [:plotting, :splits, :text], + [] + ) + ) + ], + 
rect: + @mark_opts ++ + [ + doc: @mark_rect_doc, + default: + deep_merge_kw( + @default_split_rect, + Application.compile_env( + :exgboost, + [:plotting, :splits, :rect], + [] + ) + ) + ], + children: + @mark_opts ++ + [ + doc: @mark_text_doc, + default: + deep_merge_kw( + @default_split_children, + Application.compile_env( + :exgboost, + [:plotting, :splits, :children], + [] + ) + ) + ] + ], + default: [ + text: + deep_merge_kw( + @default_split_text, + Application.compile_env(:exgboost, [:plotting, :splits, :text], []) + ), + rect: + deep_merge_kw( + @default_split_rect, + Application.compile_env(:exgboost, [:plotting, :splits, :rect], []) + ), + children: + deep_merge_kw( + @default_split_children, + Application.compile_env( + :exgboost, + [:plotting, :splits, :children], + [] + ) + ) + ] + ], + node_width: [ + doc: "The width of each node in pixels", + type: :pos_integer, + default: Application.compile_env(:exgboost, [:plotting, :node_width], 100) + ], + node_height: [ + doc: "The height of each node in pixels", + type: :pos_integer, + default: Application.compile_env(:exgboost, [:plotting, :node_height], 45) + ], + space_between: [ + doc: "The space between the rectangular marks in pixels.", + type: :keyword_list, + keys: [ + nodes: [ + doc: "Space between marks within the same depth of the tree.", + type: :pos_integer, + default: Application.compile_env(:exgboost, [:plotting, :space_between, :nodes], 10) + ], + levels: [ + doc: "Space between each rank / depth of the tree.", + type: :pos_integer, + default: Application.compile_env(:exgboost, [:plotting, :space_between, :levels], 100) + ] + ], + default: [ + nodes: Application.compile_env(:exgboost, [:plotting, :space_between, :nodes], 10), + levels: Application.compile_env(:exgboost, [:plotting, :space_between, :levels], 100) + ] + ], + yes: [ + doc: "Specifies characteristics of links between nodes where the split condition is true", + type: :keyword_list, + keys: [ + path: + @mark_opts ++ + [ + doc: 
@mark_path_doc, + default: + deep_merge_kw( + @default_yes_path, + Application.compile_env(:exgboost, [:plotting, :yes, :path], []) + ) + ], + text: + @mark_opts ++ + [ + doc: @mark_text_doc, + default: + deep_merge_kw( + @default_yes_text, + Application.compile_env(:exgboost, [:plotting, :yes, :text], []) + ) + ] + ], + default: [ + path: + deep_merge_kw( + @default_yes_path, + Application.compile_env(:exgboost, [:plotting, :yes, :path], []) + ), + text: + deep_merge_kw( + @default_yes_text, + Application.compile_env(:exgboost, [:plotting, :yes, :text], []) + ) + ] + ], + no: [ + doc: "Specifies characteristics of links between nodes where the split condition is false", + type: :keyword_list, + keys: [ + path: + @mark_opts ++ + [ + doc: @mark_path_doc, + default: + deep_merge_kw( + @default_no_path, + Application.compile_env(:exgboost, [:plotting, :no, :path], []) + ) + ], + text: + @mark_opts ++ + [ + doc: @mark_text_doc, + default: + deep_merge_kw( + @default_no_text, + Application.compile_env(:exgboost, [:plotting, :no, :path], []) + ) + ] + ], + default: [ + path: + deep_merge_kw( + @default_no_path, + Application.compile_env(:exgboost, [:plotting, :no, :path], []) + ), + text: + deep_merge_kw( + @default_no_text, + Application.compile_env(:exgboost, [:plotting, :no, :path], []) + ) + ] + ], + validate: [ + doc: "Whether to validate the Vega specification against the Vega schema", + type: :boolean, + default: Application.compile_env(:exgboost, [:plotting, :validate], true) + ], + index: [ + doc: + "The zero-indexed index of the tree to plot. If `nil`, plots all trees using a dropdown selector to switch between trees", + type: {:or, [nil, :non_neg_integer]}, + default: Application.compile_env(:exgboost, [:plotting, :index], 0) + ], + depth: [ + doc: + "The depth of the tree to plot. 
If `nil`, plots all levels (cick on a node to expand/collapse)", + type: {:or, [nil, :non_neg_integer]}, + default: Application.compile_env(:exgboost, [:plotting, :depth], nil) + ] + ] + + @plotting_schema NimbleOptions.new!(@plotting_params) + @defaults NimbleOptions.validate!([], @plotting_schema) + + @moduledoc """ + Functions for plotting EXGBoost `Booster` models using [Vega](https://vega.github.io/vega/) + + Fundamentally, all this module does is convert a `Booster` into a format that can be + ingested by Vega, and apply some default configuations that only account for a subset of the configurations + that can be set by a Vega spec directly. The functions provided in this module are designed to have opinionated + defaults that can be used to quickly visualize a model, but the full power of Vega is available by using the + `to_tabular/1` function to convert the model into a tabular format, and then using the `plot/2` function + to convert the tabular format into a Vega specification. + + ## Default Vega Specification + + The default Vega specification is designed to be a good starting point for visualizing a model, but it is + possible to customize the specification by passing in a map of Vega properties to the `plot/2` function. + Refer to `Custom Vega Specifications` for more details on how to do this. + + By default, the Vega specification includes the following entities to use for rendering the model: + * `:width` - The width of the plot in pixels + * `:height` - The height of the plot in pixels + * `:padding` - The padding in pixels to add around the visualization. If a number, specifies padding for all sides. 
If an object, the value should have the format `[left: value, right: value, top: value, bottom: value]` + * `:leaves` - Specifies characteristics of leaf nodes + * `:splits` - Specifies characteristics of split (inner) nodes + * `:yes` / `:no` - Specifies characteristics of links between nodes where the split condition is true / false + + ## Custom Vega Specifications + + The default Vega specification is designed to be a good starting point for visualizing a model, but it is + possible to customize the specification by passing in a map of Vega properties to the `plot/2` function. + You can find the full list of Vega properties [here](https://vega.github.io/vega/docs/specification/). + + It is suggested that you use the data attributes provided by the default specification as a starting point, since they + provide the necessary data transformation to convert the model into a tree structure that can be visualized by Vega. + If you would like to customize the default specification, you can use `EXGBoost.Plotting.plot/1` to get the default + specification, and then modify it as needed. + + Once you have a custom specification, you can pass it to `VegaLite.from_json/1` to create a new `VegaLite` struct, after which + you can use the functions provided by the `VegaLite` module to render the model. + + ## Specification Validation + You can optionally validate your specification against the Vega schema by passing the `validate: true` option to `plot/2`. + This will raise an error if the specification is invalid. This is useful if you are creating a custom specification and want + to ensure that it is valid. Note that this will only validate the specification against the Vega schema, and not against the + VegaLite schema. This requires the [`ex_json_schema`](https://hex.pm/packages/ex_json_schema) package to be installed. + + ## Livebook Integration + + This module also provides a `Kino.Render` implementation for `EXGBoost.Booster` which allows + models to be rendered directly in Livebook.
This is done by converting the model into a Vega specification
+  and then using the `Kino.Render` implementation for Elixir's [`VegaLite`](https://hexdocs.pm/vega_lite/VegaLite.html) API
+  to render the model.
+
+  The Vega specification is then passed to [VegaLite](https://hexdocs.pm/vega_lite/readme.html) for rendering.
+
+  ## Plotting Parameters
+
+  This module exposes a high-level API for customizing the EXGBoost model visualization, but it is also possible to
+  customize the Vega specification directly. You can also choose to pass in Vega Mark specifications to customize the
+  appearance of the nodes and links in the visualization outside of the parameters specified below. Refer to the
+  [Vega documentation](https://vega.github.io/vega/docs/marks/) for more details on how to do this.
+
+  #{NimbleOptions.docs(@plotting_params)}
+
+
+  ## Styles
+
+  Styles are a keyword-map that adhere to the plotting schema as defined in `EXGBoost.Plotting`.
+  `EXGBoost.Plotting.Styles` provides a set of predefined styles that can be used to quickly customize the appearance of the visualization.
+
+
+  Refer to the `EXGBoost.Plotting.Styles` module for a list of available styles. You can pass a style to the `:style`
+  option as an atom or string, and it will be applied to the visualization. A style supplies the value for every option
+  that is left at its default, while an option that you pass explicitly with a non-default value takes precedence over the style.
+  For example, if you pass `:solarized_light` as the style, and also pass a custom `:background`, your `:background` is used;
+  if you do not pass `:background` (or pass its default value), the `:solarized_light` style's own `:background` is used.
+  """
+
+  @doc """
+  Returns the JSON schema that `plot/2` validates Vega specifications against.
+  """
+  @spec get_schema() :: ExJsonSchema.Schema.Root.t()
+  def get_schema(), do: @schema
+
+  @doc """
+  Returns the plotting options with every option set to its default value.
+  """
+  @spec get_defaults() :: Keyword.t()
+  def get_defaults(), do: @defaults
+
+  @doc """
+  Returns the registered styles as `{name, style}` pairs.
+  """
+  @spec get_styles() :: [{atom(), style()}]
+  def get_styles(), do: @styles
+
+  # Validates `spec` against the schema returned by `get_schema/0`.
+  # Returns `spec` unchanged when valid; raises `ArgumentError` listing the
+  # validation errors otherwise.
+  defp validate_spec(spec) do
+    case ExJsonSchema.Validator.validate(get_schema(), spec) do
+      :ok ->
+        spec
+
+      {:error, errors} ->
+        raise(
+          ArgumentError,
+          "Invalid Vega specification: #{inspect(errors)}"
+        )
+    end
+  end
+
+  # Flattens one decoded tree node (and, recursively, its children) into a list
+  # of row maps tagged with the tree index `idx` and the id of `parent`.
+  defp _to_tabular(idx, node, parent) do
+    node = Map.put(node, "tree_id", idx)
+    node = if parent, do: Map.put(node, "parentid", parent), else: node
+
+    # Node ids are shifted to be 1-based. Every field that refers to a node id
+    # ("yes", "no", "missing") is shifted the same way so that it keeps
+    # pointing at the same row; fields that are absent stay `nil`.
+    shift = &if(&1, do: &1 + 1)
+
+    node =
+      new_node(node)
+      |> Map.update("nodeid", nil, &(&1 + 1))
+      |> Map.update("yes", nil, shift)
+      |> Map.update("no", nil, shift)
+      |> Map.update("missing", nil, shift)
+
+    case Map.pop(node, "children") do
+      {nil, row} ->
+        [row]
+
+      {children, row} ->
+        [row | Enum.flat_map(children, &_to_tabular(idx, &1, Map.get(row, "nodeid")))]
+    end
+  end
+
+  # Fills in `nil` for every table column that `params` does not provide, so
+  # that all rows share the same set of keys.
+  defp new_node(params = %{}) do
+    Map.merge(
+      %{
+        "nodeid" => nil,
+        "depth" => nil,
+        "missing" => nil,
+        "no" => nil,
+        "leaf" => nil,
+        "split" => nil,
+        "split_condition" => nil,
+        "yes" => nil,
+        "tree_id" => nil,
+        "parentid" => nil
+      },
+      params
+    )
+  end
+
+  @doc """
+  Outputs details of the tree in a tabular format which can be consumed
+  by plotting libraries such as Vega. Outputs as a list of maps, where
+  each map represents a node in the tree.
+
+  Node ids (`nodeid`, `parentid`, `yes`, `no`, `missing`) are 1-based, i.e. one
+  greater than the ids found in the booster's JSON dump.
+
+  Table columns:
+  - tree_id: The tree id
+  - nodeid: The node id
+  - parentid: The parent node id
+  - split: The split feature
+  - split_condition: The split condition
+  - yes: The node id of the left child
+  - no: The node id of the right child
+  - missing: The node id of the missing child
+  - depth: The depth of the node, as reported by the booster's JSON dump
+  - leaf: The leaf value if it is a leaf node
+  """
+  @spec to_tabular(EXGBoost.Booster.t()) :: [map()]
+  def to_tabular(booster) do
+    booster
+    |> EXGBoost.Booster.get_dump(format: :json)
+    |> Enum.with_index()
+    # flat_map keeps the trees in dump order without repeatedly copying an
+    # accumulator list
+    |> Enum.flat_map(fn {tree, index} ->
+      _to_tabular(index, Jason.decode!(tree), nil)
+    end)
+  end
+
+  @doc """
+  Generates the necessary data transformation to convert the model into a tree structure that can be visualized by Vega.
+
+  This function is useful if you want to create a custom Vega specification, but still want to use the data transformation
+  provided by the default specification.
+
+  ## Options
+
+    * `:rankdir` - Determines the direction of the graph. Accepts one of `:tb`, `:lr`, `:bt`, or `:rl`.
Defaults to `:tb`
+
+  ## Data sets
+
+  The returned map contains the `"$schema"` entry and a `"data"` list. Among the data sets,
+  `"splitNodes"` and `"leafNodes"` hold the currently visible split and leaf nodes, and
+  `"links"`, `"yesPaths"` and `"noPaths"` hold the link paths between visible nodes. Rows of
+  `"fullTreeLayout"` carry an `isLeaf` field that is true for nodes that have a leaf value.
+  """
+  @spec get_data_spec(booster :: EXGBoost.Booster.t(), opts :: keyword()) :: map()
+  def get_data_spec(booster, opts \\ []) do
+    opts = NimbleOptions.validate!(opts, @plotting_schema)
+
+    # Resolve the style (none, an inline keyword list, or the name of a style
+    # function on this module) and then fill any remaining gaps from the defaults.
+    opts =
+      cond do
+        opts[:style] in [nil, false] ->
+          opts
+
+        Keyword.keyword?(opts[:style]) ->
+          deep_merge_kw(opts[:style], opts, @defaults)
+
+        true ->
+          style = apply(__MODULE__, opts[:style], [])
+          deep_merge_kw(style, opts, @defaults)
+      end
+      |> then(&deep_merge_kw(@defaults, &1))
+
+    # Output field names of the Vega `tree` transform; the x and y fields swap
+    # places for the horizontal directions.
+    layout_fields =
+      case opts[:rankdir] do
+        vertical when vertical in [:tb, :bt] ->
+          ["x", "y", "depth", "children"]
+
+        horizontal when horizontal in [:lr, :rl] ->
+          ["y", "x", "depth", "children"]
+      end
+
+    %{
+      "$schema" => "https://vega.github.io/schema/vega/v5.json",
+      "data" => [
+        %{"name" => "tree", "values" => to_tabular(booster)},
+        %{
+          "name" => "treeCalcs",
+          "source" => "tree",
+          "transform" => [
+            %{"expr" => "datum.tree_id === selectedTree", "type" => "filter"},
+            %{"key" => "nodeid", "parentKey" => "parentid", "type" => "stratify"},
+            %{
+              "as" => layout_fields,
+              "method" => "tidy",
+              "separation" => %{"signal" => "false"},
+              "type" => "tree"
+            }
+          ]
+        },
+        %{
+          "name" => "treeChildren",
+          "source" => "treeCalcs",
+          "transform" => [
+            %{
+              "as" => ["childrenObjects"],
+              "fields" => ["parentid"],
+              "groupby" => ["parentid"],
+              "ops" => ["values"],
+              "type" => "aggregate"
+            },
+            %{
+              "as" => "childrenIds",
+              "expr" => "pluck(datum.childrenObjects,'nodeid')",
+              "type" => "formula"
+            }
+          ]
+        },
+        %{
+          "name" => "treeAncestors",
+          "source" => "treeCalcs",
+          "transform" => [
+            %{
+              "as" => "treeAncestors",
+              "expr" => "treeAncestors('treeCalcs', datum.nodeid, 'root')",
+              "type" => "formula"
+            },
+            %{"fields" => ["treeAncestors"], "type" => "flatten"},
+            %{
+              "as" => "allParents",
+              "expr" => "datum.treeAncestors.parentid",
+              "type" => "formula"
+            }
+          ]
+        },
+        %{
+          "name" => "treeChildrenAll",
+          "source" => "treeAncestors",
+          "transform" => [
+            %{
+              "fields" => [
+                "allParents",
+                "nodeid",
+                "name",
+                "parentid",
+                "x",
+                "y",
+                "depth",
+                "children"
+              ],
+              "type" => "project"
+            },
+            %{
+              "as" => ["allChildrenObjects", "allChildrenCount", "id"],
+              "fields" => ["parentid", "parentid", "nodeid"],
+              "groupby" => ["allParents"],
+              "ops" => ["values", "count", "min"],
+              "type" => "aggregate"
+            },
+            %{
+              "as" => "allChildrenIds",
+              "expr" => "pluck(datum.allChildrenObjects,'nodeid')",
+              "type" => "formula"
+            }
+          ]
+        },
+        %{
+          "name" => "treeClickStoreTemp",
+          "source" => "treeAncestors",
+          "transform" => [
+            %{
+              "expr" =>
+                "startingDepth != -1 ? datum.depth <= startingDepth : node != 0 && !isExpanded ? datum.parentid == node: node != 0 && isExpanded ? datum.allParents == node : false",
+              "type" => "filter"
+            },
+            %{
+              "fields" => ["nodeid", "parentid", "x", "y", "depth", "children"],
+              "type" => "project"
+            },
+            %{
+              "fields" => ["nodeid"],
+              "groupby" => ["nodeid", "parentid", "x", "y", "depth", "children"],
+              "ops" => ["min"],
+              "type" => "aggregate"
+            }
+          ]
+        },
+        %{
+          "name" => "treeClickStorePerm",
+          "on" => [
+            %{"insert" => "data('treeClickStoreTemp')", "trigger" => "startingDepth >= 0"},
+            %{
+              "insert" => "!isExpanded ? data('treeClickStoreTemp'): false",
+              "trigger" => "node"
+            },
+            %{"remove" => "isExpanded ? data('treeClickStoreTemp'): false", "trigger" => "node"}
+          ],
+          "values" => []
+        },
+        %{
+          "name" => "treeLayout",
+          "source" => "tree",
+          "transform" => [
+            %{"expr" => "datum.tree_id === selectedTree", "type" => "filter"},
+            %{
+              "expr" => "indata('treeClickStorePerm', 'nodeid', datum.nodeid)",
+              "type" => "filter"
+            },
+            %{"key" => "nodeid", "parentKey" => "parentid", "type" => "stratify"},
+            %{
+              "as" => layout_fields,
+              "method" => "tidy",
+              "nodeSize" => [
+                %{"signal" => "nodeWidth + spaceBetweenNodes"},
+                %{"signal" => "nodeHeight+ spaceBetweenLevels"}
+              ],
+              "separation" => %{"signal" => "false"},
+              "type" => "tree"
+            },
+            %{
+              "as" => "y",
+              "expr" => "#{if opts[:rankdir] == :bt, do: -1, else: 1}*(datum.y+(height/10))",
+              "type" => "formula"
+            },
+            %{
+              "as" => "x",
+              "expr" => "#{if opts[:rankdir] == :rl, do: -1, else: 1}*(datum.x+(width/2))",
+              "type" => "formula"
+            },
+            %{"type" => "extent", "field" => "x", "signal" => "x_extent"},
+            %{"type" => "extent", "field" => "y", "signal" => "y_extent"},
+            %{"as" => "xscaled", "expr" => "scale('xscale',datum.x)", "type" => "formula"},
+            %{"as" => "parent", "expr" => "datum.parentid", "type" => "formula"}
+          ]
+        },
+        %{
+          "name" => "fullTreeLayout",
+          "source" => "treeLayout",
+          "transform" => [
+            %{
+              "fields" => ["nodeid"],
+              "from" => "treeChildren",
+              "key" => "parentid",
+              "type" => "lookup",
+              "values" => ["childrenObjects", "childrenIds"]
+            },
+            %{
+              "fields" => ["nodeid"],
+              "from" => "treeChildrenAll",
+              "key" => "allParents",
+              "type" => "lookup",
+              "values" => ["allChildrenIds", "allChildrenObjects"]
+            },
+            %{
+              "fields" => ["nodeid"],
+              "from" => "treeCalcs",
+              "key" => "nodeid",
+              "type" => "lookup",
+              "values" => ["children"]
+            },
+            %{
+              "as" => "treeParent",
+              "expr" => "reverse(pluck(treeAncestors('treeCalcs', datum.nodeid), 'nodeid'))[1]",
+              "type" => "formula"
+            },
+            # Only leaf rows carry a non-null "leaf" value (split rows get `nil`
+            # from `new_node/1`), so this is true exactly for leaves.
+            %{"as" => "isLeaf", "expr" => "datum.leaf != null", "type" => "formula"}
+          ]
+        },
+        %{
+          "name" => "splitNodes",
+          "source" => "fullTreeLayout",
+          "transform" => [
+            %{
+              "type" => "filter",
+              "expr" => "indata('treeClickStorePerm', 'nodeid', datum.nodeid) && !datum.isLeaf"
+            }
+          ]
+        },
+        %{
+          "name" => "leafNodes",
+          "source" => "fullTreeLayout",
+          "transform" => [
+            %{
+              "type" => "filter",
+              "expr" => "indata('treeClickStorePerm', 'nodeid', datum.nodeid) && datum.isLeaf"
+            }
+          ]
+        },
+        %{
+          "name" => "links",
+          "source" => "treeLayout",
+          "transform" => [
+            %{"type" => "treelinks"},
+            %{
+              "orient" =>
+                case opts[:rankdir] do
+                  vertical when vertical in [:tb, :bt] ->
+                    "vertical"
+
+                  horizontal when horizontal in [:lr, :rl] ->
+                    "horizontal"
+                end,
+              "shape" => "line",
+              "sourceX" => %{
+                "expr" =>
+                  case opts[:rankdir] do
+                    vertical when vertical in [:tb, :bt] ->
+                      "scale('xscale', datum.source.x)"
+
+                    :lr ->
+                      "scale('xscale', datum.source.x) + scaledNodeWidth/2"
+
+                    :rl ->
+                      "scale('xscale', datum.source.x) - scaledNodeWidth/2"
+                  end
+              },
+              "sourceY" => %{
+                "expr" =>
+                  case opts[:rankdir] do
+                    vertical when vertical in [:tb, :bt] ->
+                      "scale('yscale', datum.source.y)"
+
+                    horizontal when horizontal in [:lr, :rl] ->
+                      "scale('yscale', datum.source.y) - scaledNodeHeight/2"
+                  end
+              },
+              "targetX" => %{
+                "expr" =>
+                  case opts[:rankdir] do
+                    vertical when vertical in [:tb, :bt] ->
+                      "scale('xscale', datum.target.x)"
+
+                    :lr ->
+                      "scale('xscale', datum.target.x) - scaledNodeWidth/2"
+
+                    :rl ->
+                      "scale('xscale', datum.target.x) + scaledNodeWidth/2"
+                  end
+              },
+              "targetY" => %{
+                "expr" =>
+                  case opts[:rankdir] do
+                    vertical when vertical in [:tb, :bt] ->
+                      "scale('yscale', datum.target.y) - scaledNodeHeight"
+
+                    horizontal when horizontal in [:lr, :rl] ->
+                      "scale('yscale', datum.target.y) - scaledNodeHeight/2"
+                  end
+              },
+              "type" => "linkpath"
+            },
+            %{
+              "expr" => " indata('treeClickStorePerm', 'nodeid',
datum.target.nodeid)",
+              "type" => "filter"
+            }
+          ]
+        },
+        %{
+          "name" => "yesPaths",
+          "source" => "links",
+          "transform" => [
+            %{
+              "type" => "filter",
+              "expr" => "datum.source.yes === datum.target.nodeid "
+            }
+          ]
+        },
+        %{
+          "name" => "noPaths",
+          "source" => "links",
+          "transform" => [
+            %{
+              "type" => "filter",
+              "expr" => "datum.source.yes !== datum.target.nodeid "
+            }
+          ]
+        }
+      ]
+    }
+  end
+
+  @doc """
+  Builds the Vega specification for `booster` and returns it as a `VegaLite` struct.
+
+  `opts` is validated against the plotting schema, the `:style` option (if any) is resolved,
+  and the result is combined with the data sets from `get_data_spec/2`, the marks, scales and
+  signals defined here, and any options that are top-level properties of the Vega schema.
+
+  When `:space_between` is left at its defaults, the spacing is adjusted by the difference
+  between the requested and the default node width / height.
+
+  When the `:validate` option is truthy the specification is checked against the Vega schema
+  and an `ArgumentError` is raised if it is invalid.
+  """
+  @spec plot(EXGBoost.Booster.t(), Keyword.t()) :: VegaLite.t()
+  def plot(booster, opts \\ []) do
+    spec = get_data_spec(booster, opts)
+    opts = NimbleOptions.validate!(opts, @plotting_schema)
+
+    opts =
+      cond do
+        opts[:style] in [nil, false] ->
+          opts
+
+        Keyword.keyword?(opts[:style]) ->
+          deep_merge_kw(opts[:style], opts, @defaults)
+
+        true ->
+          style = apply(__MODULE__, opts[:style], [])
+          deep_merge_kw(style, opts, @defaults)
+      end
+      |> then(&deep_merge_kw(@defaults, &1))
+
+    # Try to account for non-default node height / width to adjust spacing
+    # between nodes and levels as a quality of life improvement
+    # If they provide their own spacing, don't adjust it
+    opts =
+      cond do
+        opts[:rankdir] in [:lr, :rl] and
+            opts[:space_between][:levels] == @defaults[:space_between][:levels] ->
+          put_in(
+            opts[:space_between][:levels],
+            @defaults[:space_between][:levels] + (opts[:node_width] - @defaults[:node_width])
+          )
+
+        opts[:rankdir] in [:tb, :bt] and
+            opts[:space_between][:levels] == @defaults[:space_between][:levels] ->
+          put_in(
+            opts[:space_between][:levels],
+            @defaults[:space_between][:levels] + (opts[:node_height] - @defaults[:node_height])
+          )
+
+        true ->
+          opts
+      end
+
+    opts =
+      cond do
+        opts[:rankdir] in [:lr, :rl] and
+            opts[:space_between][:nodes] == @defaults[:space_between][:nodes] ->
+          put_in(
+            opts[:space_between][:nodes],
+            @defaults[:space_between][:nodes] + (opts[:node_height] - @defaults[:node_height])
+          )
+
+        opts[:rankdir] in [:tb, :bt] and
+            opts[:space_between][:nodes] == @defaults[:space_between][:nodes] ->
+          put_in(
+            opts[:space_between][:nodes],
+            @defaults[:space_between][:nodes] + (opts[:node_width] - @defaults[:node_width])
+          )
+
+        true ->
+          opts
+      end
+
+    # The schema root is an `allOf` of a `$ref` and an inline object; the option
+    # keys that match a property of either are copied into the spec as-is.
+    [%{"$ref" => root} | [%{"properties" => properties}]] =
+      EXGBoost.Plotting.get_schema().schema |> Map.get("allOf")
+
+    top_level_keys =
+      (ExJsonSchema.Schema.get_fragment!(EXGBoost.Plotting.get_schema(), root)
+       |> Map.get("properties")
+       |> Map.keys()) ++ Map.keys(properties)
+
+    tlk = opts |> opts_to_vl_props() |> Map.take(top_level_keys)
+
+    # Expression fragments appended to the source coordinate inside the signal
+    # expressions that position the link labels.
+    label_x_offset =
+      cond do
+        opts[:rankdir] in [:tb, :bt, :lr] -> "-nodeWidth/4"
+        opts[:rankdir] in [:rl] -> "+nodeWidth/4"
+        true -> ""
+      end
+
+    yes_label_y_offset = if opts[:rankdir] in [:lr, :rl], do: "-nodeWidth/3", else: ""
+
+    spec =
+      Map.merge(spec, %{
+        "$schema" => "https://vega.github.io/schema/vega/v5.json",
+        "marks" => [
+          %{
+            # The user's path options belong inside the "update" encoding set,
+            # exactly as for the "no" paths below; merged next to "update" they
+            # would be treated as separate encoding sets and never applied.
+            "encode" => %{
+              "update" =>
+                Map.merge(
+                  %{
+                    "path" => %{"field" => "path"},
+                    "strokeWidth" => %{
+                      "signal" => "indexof(nodeHighlight, datum.target.nodeid)> -1? 2:1"
+                    }
+                  },
+                  format_mark(opts[:yes][:path])
+                )
+            },
+            "from" => %{"data" => "yesPaths"},
+            "interactive" => false,
+            "type" => "path"
+          },
+          %{
+            "encode" => %{
+              "update" =>
+                Map.merge(
+                  %{
+                    "path" => %{"field" => "path"},
+                    "strokeWidth" => %{
+                      "signal" => "indexof(nodeHighlight, datum.target.nodeid)> -1? 2:1"
+                    }
+                  },
+                  format_mark(opts[:no][:path])
+                )
+            },
+            "from" => %{"data" => "noPaths"},
+            "interactive" => false,
+            "type" => "path"
+          },
+          %{
+            "encode" => %{
+              "update" =>
+                deep_merge_maps(
+                  %{
+                    "x" => %{
+                      "signal" =>
+                        "(scale('xscale', datum.source.x#{label_x_offset}) + scale('xscale', datum.target.x)) / 2"
+                    },
+                    "y" => %{
+                      "signal" =>
+                        "(scale('yscale', datum.source.y#{yes_label_y_offset}) + scale('yscale', datum.target.y)) / 2 - (scaledNodeHeight/2)"
+                    }
+                  },
+                  format_mark(opts[:yes][:text])
+                )
+            },
+            "from" => %{"data" => "yesPaths"},
+            "type" => "text"
+          },
+          %{
+            "encode" => %{
+              "update" =>
+                Map.merge(
+                  %{
+                    "x" => %{
+                      "signal" =>
+                        "(scale('xscale', datum.source.x#{label_x_offset}) + scale('xscale', datum.target.x)) / 2"
+                    },
+                    "y" => %{
+                      "signal" =>
+                        "(scale('yscale', datum.source.y) + scale('yscale', datum.target.y)) / 2 - (scaledNodeHeight/2)"
+                    }
+                  },
+                  format_mark(opts[:no][:text])
+                )
+            },
+            "from" => %{"data" => "noPaths"},
+            "type" => "text"
+          },
+          %{
+            "clip" => false,
+            "encode" => %{
+              "update" =>
+                Map.merge(
+                  %{
+                    "cursor" => %{"signal" => "datum.children > 0 ? 'pointer' : '' "},
+                    "height" => %{"signal" => "scaledNodeHeight"},
+                    "tooltip" => %{"signal" => ""},
+                    "width" => %{"signal" => "scaledNodeWidth"},
+                    "x" => %{"signal" => "datum.xscaled - (scaledNodeWidth / 2)"},
+                    "yc" => %{"signal" => "scale('yscale',datum.y) - (scaledNodeHeight/2)"}
+                  },
+                  format_mark(opts[:splits][:rect])
+                )
+            },
+            "from" => %{"data" => "splitNodes"},
+            "name" => "splitNode",
+            "marks" => [
+              %{
+                "encode" => %{
+                  "update" =>
+                    Map.merge(
+                      %{
+                        "limit" => %{"signal" => "scaledNodeWidth-scaledLimit"},
+                        "text" => %{
+                          "signal" =>
+                            "parent.split + ' <= ' + format(parent.split_condition, '.2f')"
+                        },
+                        "x" => %{"signal" => "(scaledNodeWidth / 2)"},
+                        "y" => %{"signal" => "scaledNodeHeight / 2"}
+                      },
+                      format_mark(opts[:splits][:text])
+                    )
+                },
+                "interactive" => false,
+                "name" => "title",
+                "type" => "text"
+              },
+              %{
+                "encode" => %{
+                  "update" =>
+                    Map.merge(
+                      %{
+                        "text" => %{"signal" => "parent.children"},
+                        "x" => %{"signal" => "item.mark.group.width - (9/ span(xdom))*width"},
+                        "y" => %{"signal" => "item.mark.group.height/2"}
+                      },
+                      format_mark(opts[:splits][:children])
+                    )
+                },
+                "interactive" => false,
+                "type" => "text"
+              }
+            ],
+            "type" => "group"
+          },
+          %{
+            "clip" => false,
+            "encode" => %{
+              "update" =>
+                Map.merge(
+                  %{
+                    "cursor" => %{"signal" => "datum.children > 0 ? 'pointer' : '' "},
+                    "height" => %{"signal" => "scaledNodeHeight"},
+                    "tooltip" => %{"signal" => ""},
+                    "width" => %{"signal" => "scaledNodeWidth"},
+                    "x" => %{"signal" => "datum.xscaled - (scaledNodeWidth / 2)"},
+                    "yc" => %{"signal" => "scale('yscale',datum.y) - (scaledNodeHeight/2)"}
+                  },
+                  format_mark(opts[:leaves][:rect])
+                )
+            },
+            "from" => %{"data" => "leafNodes"},
+            "name" => "leafNode",
+            "marks" => [
+              %{
+                "encode" => %{
+                  "update" =>
+                    Map.merge(
+                      %{
+                        "limit" => %{"signal" => "scaledNodeWidth-scaledLimit"},
+                        "text" => %{
+                          "signal" => "'leaf = ' + format(parent.leaf, '.2f')"
+                        },
+                        "x" => %{"signal" => "scaledNodeWidth / 2"},
+                        "y" => %{"signal" => "scaledNodeHeight / 2"}
+                      },
+                      format_mark(opts[:leaves][:text])
+                    )
+                },
+                "interactive" => false,
+                "name" => "title",
+                "type" => "text"
+              }
+            ],
+            "type" => "group"
+          }
+        ],
+        "scales" => [
+          %{
+            "domain" => %{"signal" => "xdom"},
+            "name" => "xscale",
+            "range" => %{"signal" => "xrange"},
+            "zero" => false
+          },
+          %{
+            "domain" => %{"signal" => "ydom"},
+            "name" => "yscale",
+            "range" => %{"signal" => "yrange"},
+            "zero" => false
+          }
+        ],
+        "signals" => [
+          # With an explicit `:index` the tree is fixed; otherwise a dropdown
+          # bound to the sorted, distinct tree ids is added.
+          Map.merge(
+            %{
+              "name" => "selectedTree",
+              "value" => opts[:index] || 0
+            },
+            if(opts[:index],
+              do: %{},
+              else: %{
+                "bind" => %{
+                  "input" => "select",
+                  "options" =>
+                    to_tabular(booster)
+                    |> Enum.map(&Map.get(&1, "tree_id"))
+                    |> Enum.uniq()
+                    |> Enum.sort()
+                }
+              }
+            )
+          ),
+          %{
+            "name" => "node",
+            "on" => [
+              %{
+                "events" => %{"markname" => "splitNode", "type" => "click"},
+                "update" => "datum.nodeid"
+              }
+            ],
+            "value" => 0
+          },
+          %{
+            "name" => "nodeHighlight",
+            "on" => [
+              %{
+                "events" => [
+                  %{"markname" => "splitNode", "type" => "mouseover"},
+                  %{"markname" => "leafNode", "type" => "mouseover"}
+                ],
+                "update" => "pluck(treeAncestors('treeCalcs', datum.nodeid), 'nodeid')"
+              },
+              %{"events" => %{"type" => "mouseout"}, "update" => "[0]"}
+            ],
+            "value" => "[0]"
+          },
+          %{
+            "name" => "isExpanded",
+            "on" => [
+              %{
+                "events" => %{"markname" => "splitNode", "type" => "click"},
+                "update" =>
+                  "datum.children > 0 && indata('treeClickStorePerm', 'nodeid', datum.childrenIds[0]) ? true : false"
+              }
+            ],
+            "value" => 0
+          },
+          %{"name" => "xrange", "update" => "[0, width]"},
+          %{"name" => "yrange", "update" => "[0, height]"},
+          %{
+            "name" => "down",
+            "on" => [%{"events" => "mousedown", "update" => "xy()"}],
+            "value" => nil
+          },
+          %{
+            "name" => "xcur",
+            "on" => [%{"events" => "mousedown", "update" => "slice(xdom)"}],
+            "value" => nil
+          },
+          %{
+            "name" => "ycur",
+            "on" => [%{"events" => "mousedown", "update" => "slice(ydom)"}],
+            "value" => nil
+          },
+          %{
+            "name" => "delta",
+            "on" => [
+              %{
+                "events" => [
+                  %{
+                    "between" => [
+                      %{"type" => "mousedown"},
+                      %{"source" => "window", "type" => "mouseup"}
+                    ],
+                    "consume" => true,
+                    "source" => "window",
+                    "type" => "mousemove"
+                  }
+                ],
+                "update" => "down ? [down[0]-x(), down[1]-y()] : [0,0]"
+              }
+            ],
+            "value" => [0, 0]
+          },
+          %{
+            "name" => "anchor",
+            "on" => [
+              %{"events" => "wheel", "update" => "[invert('xscale', x()), invert('yscale', y())]"}
+            ],
+            "value" => [0, 0]
+          },
+          %{"name" => "xext", "update" => "[0,width]"},
+          %{"name" => "yext", "update" => "[0,height]"},
+          %{
+            "name" => "zoom",
+            "on" => [
+              %{
+                "events" => "wheel!",
+                "force" => true,
+                "update" => "pow(1.001, event.deltaY * pow(16, event.deltaMode))"
+              }
+            ],
+            "value" => 1
+          },
+          %{
+            "name" => "xdom",
+            "on" => [
+              %{
+                "events" => %{"signal" => "delta"},
+                "update" =>
+                  "[xcur[0] + span(xcur) * delta[0] / width, xcur[1] + span(xcur) * delta[0] / width]"
+              },
+              %{
+                "events" => %{"signal" => "zoom"},
+                "update" =>
+                  "[anchor[0] + (xdom[0] - anchor[0]) * zoom, anchor[0] + (xdom[1] - anchor[0]) * zoom]"
+              }
+            ],
+            "update" => "[x_extent[0] - nodeWidth/ 2, x_extent[1] + nodeWidth / 2]"
+          },
+          %{
+            "name" => "ydom",
+            "on" => [
+              %{
+                "events" => %{"signal" => "delta"},
+                "update" =>
+                  "[ycur[0] + span(ycur) * delta[1] / height, ycur[1] + span(ycur) * delta[1] / height]"
+              },
+              %{
+                "events" => %{"signal" => "zoom"},
+                "update" =>
+                  "[anchor[1] + (ydom[0] - anchor[1]) * zoom, anchor[1] + (ydom[1] - anchor[1]) * zoom]"
+              }
+            ],
+            "update" => "[y_extent[0] - nodeHeight, y_extent[1] + nodeHeight/3]"
+          },
+          %{"name" => "scaledNodeWidth", "update" => "(nodeWidth/ span(xdom))*width"},
+          %{"name" => "scaledNodeHeight", "update" => "abs(nodeHeight/ span(ydom))*height"},
+          %{"name" => "scaledLimit", "update" => "(20/ span(xdom))*width"},
+          %{"name" => "spaceBetweenLevels", "value" => opts[:space_between][:levels]},
+          %{"name" => "spaceBetweenNodes", "value" => opts[:space_between][:nodes]},
+          %{"name" => "nodeWidth", "value" => opts[:node_width]},
+          %{"name" => "nodeHeight", "value" => opts[:node_height]},
+          %{
+            "name" => "startingDepth",
+            "on" => [%{"events" => %{"throttle" => 0, "type" => "timer"}, "update" => "-1"}],
+            # Without an explicit `:depth`, start at the largest depth found in the table
+            "value" =>
+              opts[:depth] ||
+                to_tabular(booster)
+                |> Enum.map(&Map.get(&1, "depth"))
+                |> Enum.filter(& &1)
+                |> Enum.max()
+          }
+        ]
+      })
+
+    spec = Map.merge(spec, tlk) |> Map.delete("style")
+    spec = if opts[:validate], do: validate_spec(spec), else: spec
+
+    Jason.encode!(spec) |> VegaLite.from_json()
+  end
+
+  # Helpers (from https://github.com/livebook-dev/vega_lite/blob/v0.1.8/lib/vega_lite.ex#L1094)
+
+  # Converts a keyword list of options into a map with camelCase string keys.
+  defp opts_to_vl_props(opts) do
+    opts |> Map.new() |> to_vl()
+  end
+
+  # Recursively converts atoms to camelCase strings and atom-keyed keyword
+  # lists to maps; booleans, nil, structs and other values pass through.
+  defp to_vl(value) when value in [true, false, nil], do: value
+
+  defp to_vl(atom) when is_atom(atom), do: to_vl_key(atom)
+
+  defp to_vl(%_{} = struct), do: struct
+
+  defp to_vl(map) when is_map(map) do
+    Map.new(map, fn {key, value} ->
+      {to_vl(key), to_vl(value)}
+    end)
+  end
+
+  defp to_vl([{key, _} | _] = keyword) when is_atom(key) do
+    Map.new(keyword, fn {key, value} ->
+      {to_vl(key), to_vl(value)}
+    end)
+  end
+
+  defp to_vl(list) when is_list(list) do
+    Enum.map(list, &to_vl/1)
+  end
+
+  defp to_vl(value), do: value
+
+  defp to_vl_key(key) when is_atom(key) do
+    key |>
to_string() |> snake_to_camel()
+  end
+
+  # Turns a keyword list of mark options into a Vega encoding map: every option
+  # becomes `%{"value" => v}`, except "fontSize", which becomes a signal
+  # expression built from the value, `span(xdom)` and `width`.
+  defp format_mark(opts) do
+    opts
+    |> opts_to_vl_props()
+    |> Map.new(fn
+      {"fontSize", value} ->
+        {"fontSize", %{"signal" => "(#{value}/ span(xdom))*width"}}
+
+      {other, value} ->
+        {other, %{"value" => value}}
+    end)
+  end
+
+  # "some_option_name" -> "someOptionName"
+  defp snake_to_camel(string) do
+    [part | parts] = String.split(string, "_")
+    Enum.join([String.downcase(part, :ascii) | Enum.map(parts, &capitalize/1)])
+  end
+
+  # Upper-cases a leading ASCII lowercase letter (the upper-case letter is 32
+  # code points below its lower-case form); anything else is returned as-is.
+  defp capitalize(<<first, rest::binary>>) when first in ?a..?z, do: <<first - 32, rest::binary>>
+  defp capitalize(rest), do: rest
+
+  # NOTE(review): this looks like leftover debugging output. It is not an
+  # `@after_compile` callback (those take two arguments: env and bytecode) —
+  # confirm nothing calls it and remove.
+  @doc false
+  def __after_compile__(env) do
+    IO.inspect(env)
+  end
+end
+
+defimpl Kino.Render, for: EXGBoost.Booster do
+  # Renders a booster in Livebook by plotting it with the default options and
+  # delegating to the `Kino.Render` implementation of the resulting value.
+  def to_livebook(booster) do
+    EXGBoost.Plotting.plot(booster) |> Kino.Render.to_livebook()
+  end
+end
diff --git a/lib/exgboost/plotting/style.ex b/lib/exgboost/plotting/style.ex
new file mode 100644
index 0000000..b1ca860
--- /dev/null
+++ b/lib/exgboost/plotting/style.ex
@@ -0,0 +1,50 @@
+defmodule EXGBoost.Plotting.Style do
+  @moduledoc false
+  defmacro __using__(_opts) do
+    quote do
+      @type style :: Keyword.t()
+      import EXGBoost.Plotting.Style
+      Module.register_attribute(__MODULE__, :styles, accumulate: true)
+    end
+  end
+
+  # Deep-merges keyword list `b` into `a`.
+  #
+  # For a key present in both:
+  #   * when both values are keyword lists they are merged recursively, using
+  #     the entry of `ignore_set` under the same key as the nested ignore set;
+  #   * otherwise the value from `b` wins, unless `ignore_set` holds that very
+  #     value under the same key, in which case the value from `a` is kept.
+  #
+  # Passing the defaults as `ignore_set` therefore lets `a` (e.g. a style)
+  # survive wherever `b` merely carries a default, at every nesting level.
+  def deep_merge_kw(a, b, ignore_set \\ []) do
+    Keyword.merge(a, b, fn key, val_a, val_b ->
+      cond do
+        # Plain (non-keyword) lists are treated as opaque values instead of
+        # being handed to Keyword.merge/3, which would raise on them.
+        is_list(val_a) and is_list(val_b) and Keyword.keyword?(val_a) and
+            Keyword.keyword?(val_b) ->
+          deep_merge_kw(val_a, val_b, nested_ignore_set(ignore_set, key))
+
+        Keyword.has_key?(ignore_set, key) and Keyword.get(ignore_set, key) == val_b ->
+          val_a
+
+        true ->
+          val_b
+      end
+    end)
+  end
+
+  # The ignore set that applies below `key`, or `[]` when there is none.
+  defp nested_ignore_set(ignore_set, key) do
+    case Keyword.get(ignore_set, key) do
+      nested when is_list(nested) -> if Keyword.keyword?(nested), do: nested, else: []
+      _other -> []
+    end
+  end
+
+  # Deep-merges two maps; on conflicting non-map values the value from the
+  # FIRST argument (`b`) wins, at every nesting level.
+  def deep_merge_maps(b, a) do
+    Map.merge(a, b, fn
+      _key, val_a, val_b when is_map(val_a) and is_map(val_b) ->
+        # `val_b` comes from `b`, so it must stay in the first position to keep
+        # the same precedence in the nested merge.
+        deep_merge_maps(val_b, val_a)
+
+      _key, _val_a, val_b ->
+        val_b
+    end)
+  end
+
+  # Defines a zero-arity function named `style_name` that returns `body`, and
+  # registers `{style_name, body}` in the accumulating `:styles` attribute.
+  defmacro style(style_name, do: body) do
+    quote do
+      @spec unquote(style_name)() :: style
+      def unquote(style_name)(), do: unquote(body)
+      Module.put_attribute(__MODULE__, :styles,
{unquote(style_name), unquote(body)}) + end + end +end diff --git a/lib/exgboost/plotting/styles.ex b/lib/exgboost/plotting/styles.ex new file mode 100644 index 0000000..2c7a8bd --- /dev/null +++ b/lib/exgboost/plotting/styles.ex @@ -0,0 +1,25 @@ +defmodule EXGBoost.Plotting.Styles do + @bst File.cwd!() |> Path.join("test/data/model.json") |> EXGBoost.read_model() + + @moduledoc """ +
+ #{Enum.map(EXGBoost.Plotting.get_styles(), fn {name, _style} -> """ +
+

#{name}

+
+        
+        #{EXGBoost.plot_tree(@bst, style: name, height: 200, width: 300).spec |> Jason.encode!()}
+        
+      
+
+ """ end) |> Enum.join("\n\n")} +
+ """ + + @functions [{:a, "a"}, {:b, "b"}, {:c, "c"}] + for {name, ret} <- @functions do + def unquote(name)() do + unquote(ret) + end + end +end diff --git a/lib/exgboost/training.ex b/lib/exgboost/training.ex index 484e53f..3c58715 100644 --- a/lib/exgboost/training.ex +++ b/lib/exgboost/training.ex @@ -6,6 +6,8 @@ defmodule EXGBoost.Training do @spec train(DMatrix.t(), Keyword.t()) :: Booster.t() def train(%DMatrix{} = dmat, opts \\ []) do + dmat_opts = Keyword.take(opts, EXGBoost.Internal.dmatrix_feature_opts()) + valid_opts = [ callbacks: [], early_stopping_rounds: nil, @@ -13,13 +15,15 @@ defmodule EXGBoost.Training do learning_rates: nil, num_boost_rounds: 10, obj: nil, - verbose_eval: true + verbose_eval: true, + disable_default_eval_metric: false ] {opts, booster_params} = Keyword.split(opts, Keyword.keys(valid_opts)) [ callbacks: callbacks, + disable_default_eval_metric: disable_default_eval_metric, early_stopping_rounds: early_stopping_rounds, evals: evals, learning_rates: learning_rates, @@ -45,7 +49,7 @@ defmodule EXGBoost.Training do evals_dmats = Enum.map(evals, fn {x, y, name} -> - {DMatrix.from_tensor(x, y, format: :dense), name} + {DMatrix.from_tensor(x, y, Keyword.put_new(dmat_opts, :format, :dense)), name} end) bst = @@ -55,7 +59,14 @@ defmodule EXGBoost.Training do ) defaults = - default_callbacks(bst, learning_rates, verbose_eval, evals_dmats, early_stopping_rounds) + default_callbacks( + bst, + learning_rates, + verbose_eval, + evals_dmats, + early_stopping_rounds, + disable_default_eval_metric + ) callbacks = Enum.map(callbacks ++ defaults, fn %Callback{fun: fun} = callback -> @@ -119,7 +130,14 @@ defmodule EXGBoost.Training do %{state | iteration: iter} end - defp default_callbacks(bst, learning_rates, verbose_eval, evals_dmats, early_stopping_rounds) do + defp default_callbacks( + bst, + learning_rates, + verbose_eval, + evals_dmats, + early_stopping_rounds, + disable_default_eval_metric + ) do default_callbacks = [] default_callbacks = 
@@ -154,12 +172,24 @@ defmodule EXGBoost.Training do if early_stopping_rounds && evals_dmats != [] do [{_dmat, target_eval} | _tail] = Enum.reverse(evals_dmats) - # Default to the last metric - [%{"name" => metric_name} | _tail] = - EXGBoost.dump_config(bst) - |> Jason.decode!() - |> get_in(["learner", "metrics"]) - |> Enum.reverse() + # This is still somewhat hacky and relies on a modification made to + # XGBoost in the Makefile to dump the config to JSON. + # + %{"learner" => %{"metrics" => metrics, "default_metric" => default_metric}} = + EXGBoost.dump_config(bst) |> Jason.decode!() + + metric_name = + cond do + Enum.empty?(metrics) && disable_default_eval_metric -> + raise ArgumentError, + "`:early_stopping_rounds` requires at least one evaluation set. This means you have likely set `disable_default_eval_metric: true` and have not set any explicit evalutation metrics. Please supply at least one metric in the `:eval_metric` option or set `disable_default_eval_metric: false` (default option)" + + Enum.empty?(metrics) -> + default_metric + + true -> + metrics |> Enum.reverse() |> hd() |> Map.fetch!("name") + end early_stop = %Callback{ event: :after_iteration, diff --git a/lib/exgboost/training/callback.ex b/lib/exgboost/training/callback.ex index dadd0c0..8f873db 100644 --- a/lib/exgboost/training/callback.ex +++ b/lib/exgboost/training/callback.ex @@ -166,6 +166,9 @@ defmodule EXGBoost.Training.Callback do true -> early_stop = Map.update!(early_stop, :since_last_improvement, &(&1 + 1)) + # TODO: Should this actually update the best iteration and score? + # This iteration is not the best, but it is the last one, so do we want + # another way to track last iteration? 
bst = struct(bst, best_iteration: state.iteration, best_score: score) %{state | booster: bst, meta_vars: %{meta_vars | early_stop: early_stop}, status: :halt} end diff --git a/mix.exs b/mix.exs index 4c741c9..80be283 100644 --- a/mix.exs +++ b/mix.exs @@ -1,6 +1,6 @@ defmodule EXGBoost.MixProject do use Mix.Project - @version "0.4.0" + @version "0.5.0" def project do [ @@ -19,6 +19,9 @@ defmodule EXGBoost.MixProject do start_permanent: Mix.env() == :prod, compilers: [:elixir_make] ++ Mix.compilers(), deps: deps(), + name: "EXGBoost", + source_url: "https://github.com/acalejos/exgboost", + homepage_url: "https://github.com/acalejos/exgboost", docs: docs(), package: package(), preferred_cli_env: [ @@ -46,9 +49,15 @@ defmodule EXGBoost.MixProject do {:nimble_options, "~> 1.0"}, {:nx, "~> 0.5"}, {:jason, "~> 1.3"}, - {:ex_doc, "~> 0.29.0", only: :docs}, + {:ex_doc, "~> 0.31.0", only: :docs}, {:cc_precompiler, "~> 0.1.0", runtime: false}, - {:exterval, "0.1.0"} + {:exterval, "0.1.0"}, + {:ex_json_schema, "~> 0.10.2"}, + {:httpoison, "~> 2.0", runtime: false}, + {:vega_lite, "~> 0.1"}, + {:kino, "~> 0.11"}, + {:scidata, "~> 0.1", only: :dev}, + {:kino_vega_lite, "~> 0.1.9", only: :dev} ] end @@ -76,11 +85,27 @@ defmodule EXGBoost.MixProject do extras: [ "notebooks/compiled_benchmarks.livemd", "notebooks/iris_classification.livemd", - "notebooks/quantile_prediction_interval.livemd" + "notebooks/quantile_prediction_interval.livemd", + "notebooks/plotting.livemd" ], groups_for_extras: [ Notebooks: Path.wildcard("notebooks/*.livemd") ], + groups_for_functions: [ + "System / Native Config": &(&1[:type] == :system), + "Training & Prediction": &(&1[:type] == :train_pred), + Serialization: &(&1[:type] == :serialization), + Plotting: &(&1[:type] == :plotting) + ], + groups_for_modules: [ + Plotting: [EXGBoost.Plotting, EXGBoost.Plotting.Styles], + Training: [ + EXGBoost.Training, + EXGBoost.Training.Callback, + EXGBoost.Booster, + EXGBoost.Parameters + ] + ], 
before_closing_body_tag: &before_closing_body_tag/1 ] end @@ -122,6 +147,38 @@ defmodule EXGBoost.MixProject do } }); + + + + + + + """ end diff --git a/mix.lock b/mix.lock index 37368c9..73d39f7 100644 --- a/mix.lock +++ b/mix.lock @@ -1,16 +1,35 @@ %{ - "cc_precompiler": {:hex, :cc_precompiler, "0.1.7", "77de20ac77f0e53f20ca82c563520af0237c301a1ec3ab3bc598e8a96c7ee5d9", [:mix], [{:elixir_make, "~> 0.7.3", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "2768b28bf3c2b4f788c995576b39b8cb5d47eb788526d93bd52206c1d8bf4b75"}, + "castore": {:hex, :castore, "0.1.22", "4127549e411bedd012ca3a308dede574f43819fe9394254ca55ab4895abfa1a2", [:mix], [], "hexpm", "c17576df47eb5aa1ee40cc4134316a99f5cad3e215d5c77b8dd3cfef12a22cac"}, + "cc_precompiler": {:hex, :cc_precompiler, "0.1.9", "e8d3364f310da6ce6463c3dd20cf90ae7bbecbf6c5203b98bf9b48035592649b", [:mix], [{:elixir_make, "~> 0.7", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "9dcab3d0f3038621f1601f13539e7a9ee99843862e66ad62827b0c42b2f58a54"}, + "certifi": {:hex, :certifi, "2.12.0", "2d1cca2ec95f59643862af91f001478c9863c2ac9cb6e2f89780bfd8de987329", [:rebar3], [], "hexpm", "ee68d85df22e554040cdb4be100f33873ac6051387baf6a8f6ce82272340ff1c"}, "complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"}, - "earmark_parser": {:hex, :earmark_parser, "1.4.32", "fa739a0ecfa34493de19426681b23f6814573faee95dfd4b4aafe15a7b5b32c6", [:mix], [], "hexpm", "b8b0dd77d60373e77a3d7e8afa598f325e49e8663a51bcc2b88ef41838cca755"}, - "elixir_make": {:hex, :elixir_make, "0.7.7", "7128c60c2476019ed978210c245badf08b03dbec4f24d05790ef791da11aa17c", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}], "hexpm", "5bc19fff950fad52bbe5f211b12db9ec82c6b34a9647da0c2224b8b8464c7e6c"}, - "ex_doc": {:hex, :ex_doc, "0.29.4", 
"6257ecbb20c7396b1fe5accd55b7b0d23f44b6aa18017b415cb4c2b91d997729", [:mix], [{:earmark_parser, "~> 1.4.31", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_elixir, "~> 0.14", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1", [hex: :makeup_erlang, repo: "hexpm", optional: false]}], "hexpm", "2c6699a737ae46cb61e4ed012af931b57b699643b24dabe2400a8168414bc4f5"}, + "decimal": {:hex, :decimal, "2.1.1", "5611dca5d4b2c3dd497dec8f68751f1f1a54755e8ed2a966c2633cf885973ad6", [:mix], [], "hexpm", "53cfe5f497ed0e7771ae1a475575603d77425099ba5faef9394932b35020ffcc"}, + "earmark_parser": {:hex, :earmark_parser, "1.4.39", "424642f8335b05bb9eb611aa1564c148a8ee35c9c8a8bba6e129d51a3e3c6769", [:mix], [], "hexpm", "06553a88d1f1846da9ef066b87b57c6f605552cfbe40d20bd8d59cc6bde41944"}, + "elixir_make": {:hex, :elixir_make, "0.7.8", "505026f266552ee5aabca0b9f9c229cbb496c689537c9f922f3eb5431157efc7", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}, {:certifi, "~> 2.0", [hex: :certifi, repo: "hexpm", optional: true]}], "hexpm", "7a71945b913d37ea89b06966e1342c85cfe549b15e6d6d081e8081c493062c07"}, + "ex_doc": {:hex, :ex_doc, "0.31.1", "8a2355ac42b1cc7b2379da9e40243f2670143721dd50748bf6c3b1184dae2089", [:mix], [{:earmark_parser, "~> 1.4.39", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_c, ">= 0.1.1", [hex: :makeup_c, repo: "hexpm", optional: true]}, {:makeup_elixir, "~> 0.14", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1", [hex: :makeup_erlang, repo: "hexpm", optional: false]}], "hexpm", "3178c3a407c557d8343479e1ff117a96fd31bafe52a039079593fb0524ef61b0"}, + "ex_json_schema": {:hex, :ex_json_schema, "0.10.2", "7c4b8c1481fdeb1741e2ce66223976edfb9bccebc8014f6aec35d4efe964fb71", [:mix], [{:decimal, "~> 2.0", [hex: :decimal, repo: "hexpm", optional: false]}], "hexpm", "37f43be60f8407659d4d0155a7e45e7f406dab1f827051d3d35858a709baf6a6"}, "exterval": {:hex, 
:exterval, "0.1.0", "d35ec43b0f260239f859665137fac0974f1c6a8d50bf8d52b5999c87c67c63e5", [:mix], [], "hexpm", "8ed444558a501deec6563230e3124cdf242c413a95eb9ca9f39de024ad779d7f"}, - "jason": {:hex, :jason, "1.4.0", "e855647bc964a44e2f67df589ccf49105ae039d4179db7f6271dfd3843dc27e6", [:mix], [{:decimal, "~> 1.0 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: true]}], "hexpm", "79a3791085b2a0f743ca04cec0f7be26443738779d09302e01318f97bdb82121"}, - "makeup": {:hex, :makeup, "1.1.0", "6b67c8bc2882a6b6a445859952a602afc1a41c2e08379ca057c0f525366fc3ca", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "0a45ed501f4a8897f580eabf99a2e5234ea3e75a4373c8a52824f6e873be57a6"}, + "fss": {:hex, :fss, "0.1.1", "9db2344dbbb5d555ce442ac7c2f82dd975b605b50d169314a20f08ed21e08642", [:mix], [], "hexpm", "78ad5955c7919c3764065b21144913df7515d52e228c09427a004afe9c1a16b0"}, + "hackney": {:hex, :hackney, "1.20.1", "8d97aec62ddddd757d128bfd1df6c5861093419f8f7a4223823537bad5d064e2", [:rebar3], [{:certifi, "~> 2.12.0", [hex: :certifi, repo: "hexpm", optional: false]}, {:idna, "~> 6.1.0", [hex: :idna, repo: "hexpm", optional: false]}, {:metrics, "~> 1.0.0", [hex: :metrics, repo: "hexpm", optional: false]}, {:mimerl, "~> 1.1", [hex: :mimerl, repo: "hexpm", optional: false]}, {:parse_trans, "3.4.1", [hex: :parse_trans, repo: "hexpm", optional: false]}, {:ssl_verify_fun, "~> 1.1.0", [hex: :ssl_verify_fun, repo: "hexpm", optional: false]}, {:unicode_util_compat, "~> 0.7.0", [hex: :unicode_util_compat, repo: "hexpm", optional: false]}], "hexpm", "fe9094e5f1a2a2c0a7d10918fee36bfec0ec2a979994cff8cfe8058cd9af38e3"}, + "httpoison": {:hex, :httpoison, "2.2.1", "87b7ed6d95db0389f7df02779644171d7319d319178f6680438167d7b69b1f3d", [:mix], [{:hackney, "~> 1.17", [hex: :hackney, repo: "hexpm", optional: false]}], "hexpm", "51364e6d2f429d80e14fe4b5f8e39719cacd03eb3f9a9286e61e216feac2d2df"}, + "idna": {:hex, :idna, "6.1.1", 
"8a63070e9f7d0c62eb9d9fcb360a7de382448200fbbd1b106cc96d3d8099df8d", [:rebar3], [{:unicode_util_compat, "~> 0.7.0", [hex: :unicode_util_compat, repo: "hexpm", optional: false]}], "hexpm", "92376eb7894412ed19ac475e4a86f7b413c1b9fbb5bd16dccd57934157944cea"}, + "jason": {:hex, :jason, "1.4.1", "af1504e35f629ddcdd6addb3513c3853991f694921b1b9368b0bd32beb9f1b63", [:mix], [{:decimal, "~> 1.0 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: true]}], "hexpm", "fbb01ecdfd565b56261302f7e1fcc27c4fb8f32d56eab74db621fc154604a7a1"}, + "kino": {:hex, :kino, "0.12.3", "a5f48a243c60a7ac18ba23869f697b1c775fc7794e8cd55dd248ba33c6fe9445", [:mix], [{:fss, "~> 0.1.0", [hex: :fss, repo: "hexpm", optional: false]}, {:nx, "~> 0.1", [hex: :nx, repo: "hexpm", optional: true]}, {:table, "~> 0.1.2", [hex: :table, repo: "hexpm", optional: false]}], "hexpm", "a6dfa3d54ba0edec9ca6e5940154916b381901001f171c85a2d8c67869dbc2d8"}, + "kino_vega_lite": {:hex, :kino_vega_lite, "0.1.11", "d3c2a00b3685b95f91833920d06cc9b1fd7fb293a2663d89affe9aaec16a5b77", [:mix], [{:kino, "~> 0.7", [hex: :kino, repo: "hexpm", optional: false]}, {:table, "~> 0.1.0", [hex: :table, repo: "hexpm", optional: false]}, {:vega_lite, "~> 0.1.8", [hex: :vega_lite, repo: "hexpm", optional: false]}], "hexpm", "5ccd9148ce7cfcc95a137e12596cd8b95b371e9ea107e745bc262c39c5d8d48e"}, + "makeup": {:hex, :makeup, "1.1.1", "fa0bc768698053b2b3869fa8a62616501ff9d11a562f3ce39580d60860c3a55e", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "5dc62fbdd0de44de194898b6710692490be74baa02d9d108bc29f007783b0b48"}, "makeup_elixir": {:hex, :makeup_elixir, "0.16.1", "cc9e3ca312f1cfeccc572b37a09980287e243648108384b97ff2b76e505c3555", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "e127a341ad1b209bd80f7bd1620a15693a9908ed780c3b763bccf7d200c767c6"}, - 
"makeup_erlang": {:hex, :makeup_erlang, "0.1.2", "ad87296a092a46e03b7e9b0be7631ddcf64c790fa68a9ef5323b6cbb36affc72", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "f3f5a1ca93ce6e092d92b6d9c049bcda58a3b617a8d888f8e7231c85630e8108"}, - "nimble_options": {:hex, :nimble_options, "1.0.2", "92098a74df0072ff37d0c12ace58574d26880e522c22801437151a159392270e", [:mix], [], "hexpm", "fd12a8db2021036ce12a309f26f564ec367373265b53e25403f0ee697380f1b8"}, - "nimble_parsec": {:hex, :nimble_parsec, "1.3.1", "2c54013ecf170e249e9291ed0a62e5832f70a476c61da16f6aac6dca0189f2af", [:mix], [], "hexpm", "2682e3c0b2eb58d90c6375fc0cc30bc7be06f365bf72608804fb9cffa5e1b167"}, - "nx": {:hex, :nx, "0.5.3", "6ad5534f9b82429dafa12329952708c2fdd6ab01b306e86333fdea72383147ee", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "d1072fc4423809ed09beb729e73c200ce177ddecac425d9eb6fba643669623ec"}, + "makeup_erlang": {:hex, :makeup_erlang, "0.1.3", "d684f4bac8690e70b06eb52dad65d26de2eefa44cd19d64a8095e1417df7c8fd", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "b78dc853d2e670ff6390b605d807263bf606da3c82be37f9d7f68635bd886fc9"}, + "metrics": {:hex, :metrics, "1.0.1", "25f094dea2cda98213cecc3aeff09e940299d950904393b2a29d191c346a8486", [:rebar3], [], "hexpm", "69b09adddc4f74a40716ae54d140f93beb0fb8978d8636eaded0c31b6f099f16"}, + "mimerl": {:hex, :mimerl, "1.2.0", "67e2d3f571088d5cfd3e550c383094b47159f3eee8ffa08e64106cdf5e981be3", [:rebar3], [], "hexpm", "f278585650aa581986264638ebf698f8bb19df297f66ad91b18910dfc6e19323"}, + "nimble_csv": {:hex, :nimble_csv, "1.2.0", "4e26385d260c61eba9d4412c71cea34421f296d5353f914afe3f2e71cce97722", [:mix], [], "hexpm", "d0628117fcc2148178b034044c55359b26966c6eaa8e2ce15777be3bbc91b12a"}, + "nimble_options": {:hex, :nimble_options, "1.1.0", 
"3b31a57ede9cb1502071fade751ab0c7b8dbe75a9a4c2b5bbb0943a690b63172", [:mix], [], "hexpm", "8bbbb3941af3ca9acc7835f5655ea062111c9c27bcac53e004460dfd19008a99"}, + "nimble_parsec": {:hex, :nimble_parsec, "1.4.0", "51f9b613ea62cfa97b25ccc2c1b4216e81df970acd8e16e8d1bdc58fef21370d", [:mix], [], "hexpm", "9c565862810fb383e9838c1dd2d7d2c437b3d13b267414ba6af33e50d2d1cf28"}, + "nx": {:hex, :nx, "0.6.4", "948d9f42f81e63fc901d243ac0a985c8bb87358be62e27826cfd67f58bc640af", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "bb9c2e2e3545b5eb4739d69046a988daaa212d127dba7d97801c291616aff6d6"}, + "parse_trans": {:hex, :parse_trans, "3.4.1", "6e6aa8167cb44cc8f39441d05193be6e6f4e7c2946cb2759f015f8c56b76e5ff", [:rebar3], [], "hexpm", "620a406ce75dada827b82e453c19cf06776be266f5a67cff34e1ef2cbb60e49a"}, + "scidata": {:hex, :scidata, "0.1.11", "fe3358bac7d740374b4f2a7eff6a1cb02e5ee7f87f7cdb1e8648ad93c533165f", [:mix], [{:castore, "~> 0.1", [hex: :castore, repo: "hexpm", optional: false]}, {:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: false]}, {:nimble_csv, "~> 1.1", [hex: :nimble_csv, repo: "hexpm", optional: false]}, {:stb_image, "~> 0.4", [hex: :stb_image, repo: "hexpm", optional: true]}], "hexpm", "90873337a9d5fe880d640517efa93d3c07e46c8ba436de44117f581800549f93"}, + "ssl_verify_fun": {:hex, :ssl_verify_fun, "1.1.7", "354c321cf377240c7b8716899e182ce4890c5938111a1296add3ec74cf1715df", [:make, :mix, :rebar3], [], "hexpm", "fe4c190e8f37401d30167c8c405eda19469f34577987c76dde613e838bbc67f8"}, + "table": {:hex, :table, "0.1.2", "87ad1125f5b70c5dea0307aa633194083eb5182ec537efc94e96af08937e14a8", [:mix], [], "hexpm", "7e99bc7efef806315c7e65640724bf165c3061cdc5d854060f74468367065029"}, "telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", 
"dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"}, + "unicode_util_compat": {:hex, :unicode_util_compat, "0.7.0", "bc84380c9ab48177092f43ac89e4dfa2c6d62b40b8bd132b1059ecc7232f9a78", [:rebar3], [], "hexpm", "25eee6d67df61960cf6a794239566599b09e17e668d3700247bc498638152521"}, + "vega_lite": {:hex, :vega_lite, "0.1.8", "7f6119126ecaf4bc2c1854084370d7091424f5cce4795fbac044eee9963f0752", [:mix], [{:table, "~> 0.1.0", [hex: :table, repo: "hexpm", optional: false]}], "hexpm", "6c8a9271f850612dd8a90de8d1ebd433590ed07ffef76fc2397c240dc04d3fdc"}, } diff --git a/notebooks/plotting.livemd b/notebooks/plotting.livemd new file mode 100644 index 0000000..fb4c76c --- /dev/null +++ b/notebooks/plotting.livemd @@ -0,0 +1,452 @@ +# Plotting in EXGBoost + +```elixir +# Mix.install( +# [ +# {:exgboost, "~> 0.5", env: :dev} +# ] +# ) + +Mix.install([ + {:exgboost, path: "/Users/andres/Documents/exgboost"}, + {:nx, "~> 0.5"}, + {:scidata, "~> 0.1"}, + {:scholar, "~> 0.1"}, + {:kino_vega_lite, "~> 0.1"} +]) + +# This assumed you launch this livebook from its location in the exgboost/notebooks folder +``` + +## Introduction + +Much of the utility from decision trees come from their intuitiveness and ability to inform dcisions outside of the confines of a black-box model. A decision tree can be easily translated to a series of actions that can be taken on behalf of the stakeholder to achieve the desired outcome. This makes them especially useful in business decisions, where people might still want to have the final say but be as informed as possible. Additionally, tabular data is still quite popular in the business domain, which conforms to the required input for decision trees. + +Decision trees can be used for both regression and classification tasks, but classification tends to be what is most associated with decision trees. 
+ + + +This notebook will go over some of the details of the `EXGBoost.Plotting` module, including using preconfigured styles, custom styling, as well as customizing the entire visualization. + +## Plotting APIs + +There are 2 main APIs exposed to control plotting in `EXGBoost`: + +* Top-level API (`EXGBoost.plot_tree/2`) + + * Using predefined styles + * Defining custom styles + * Mix of the first 2 + +* `EXGBoost.Plotting` module API + + * Use the Vega `data` spec defined in `EXGBoost.Plotting.get_data_spec/2` + * Define your own Vega spec using the data from either `EXGBoost.Plotting.to_tabular/1` or some other means + + We will walk through each of these in detail. + +Regardless of which API you choose to use, it is helpful to understand how the plotting module works (although the higher-level the API you choose to work with, the less important it becomes). + +## Implementation Details + +The plotting functionality provided in `EXGBoost` is powered by the [Vega](https://vega.github.io/vega/) JavaScript library and the Elixir [`VegaLite`](https://hexdocs.pm/vega_lite/VegaLite.html) library which provides the piping to interop with the JavaScript libraries. **We do not actually make much use of the Elixir API provided by the Elixir VegaLite library. It is mainly used for the purposes of rendering.** + +Vega is a plotting library built on top of the very powerful [D3](https://d3js.org/) JavaScript library. Vega visualizations are defined according to the respective JSON Schema specification. Vega-Lite offers a [reduced schema](https://vega.github.io/schema/vega-lite/v5.json) compared to the [full Vega spec](https://vega.github.io/schema/vega/v5.json). `EXGBoost.Plotting` leverages several transforms which are not available in the reduced Vega-Lite schema, which is the reason for targeting the lower-level API. + +For these reasons, unfortunately we could not just implement plotting for `EXGBoost` as a composable Vega-Lite pipeline. 
This makes working dynamically with the spec a bit more unwieldy, but much care was taken to still make the high-level plotting API extensible, and if needed you can go straight to defining your own JSON spec. + +## Setup Data + +We will still be using the Iris dataset for this notebook, but if you want more details about the process of training and evaluating a model please check out the `Iris Classification with Gradient Boosting` notebook. + +So let's proceed by setting up the Iris dataset. + +```elixir +{x, y} = Scidata.Iris.download() +data = Enum.zip(x, y) |> Enum.shuffle() +{train, test} = Enum.split(data, ceil(length(data) * 0.8)) +{x_train, y_train} = Enum.unzip(train) +{x_test, y_test} = Enum.unzip(test) + +x_train = Nx.tensor(x_train) +y_train = Nx.tensor(y_train) + +x_test = Nx.tensor(x_test) +y_test = Nx.tensor(y_test) +``` + +## Train Your Booster + +Now go ahead and train your booster. We will use `early_stopping_rounds: 1` because we're not interested in the accuracy of the booster for this demonstration (*Note that we need to set `evals` to use early stopping*). + +You will notice that `EXGBoost` also provides an implementation for `Kino.Render` so that `EXGBoost.Booster`s are rendered as a plot by default. + +```elixir +booster = + EXGBoost.train( + x_train, + y_train, + num_class: 3, + objective: :multi_softprob, + num_boost_rounds: 10, + evals: [{x_train, y_train, "training"}], + verbose_eval: false, + early_stopping_rounds: 1 + ) +``` + +You'll notice that the plot doesn't display any labels for the features in the splits, and instead only shows features labelled as "f2" etc. If you provide feature labels during training, your plot will show the splits using the feature labels. 
+ +```elixir +booster = + EXGBoost.train(x_train, y_train, + num_class: 3, + objective: :multi_softprob, + num_boost_rounds: 10, + evals: [{x_train, y_train, "training"}], + verbose_eval: false, + feature_name: ["sepal length", "sepal width", "petal length", "petal width"], + early_stopping_rounds: 1 + ) +``` + +## Top-Level API + +`EXGBoost.plot_tree/2` is the quickest way to customize the output of the plot. + +This API uses [Vega `Mark`s](https://vega.github.io/vega/docs/marks/) to describe the plot. Each of the following `Mark` options accepts any of the valid keys from their respective `Mark` type as described in the Vega documentation. + +**Please note that these are passed as a `Keyword`, and as such the keys must be atoms rather than strings as the Vega docs show. Valid options for this API are `snake_cased` atoms as opposed to the `camelCased` strings the Vega docs describe, so if you wish to pass `"fontSize"` as the Vega docs show, you would instead pass it as `font_size:` in this API.** + +The plot is composed of the following parts: + +* Top-level keys: Options controlling parts of the plot outside of direct control of a `Mark`, such as `:padding`, `:autosize`, etc. Accepts any Vega [top-level key](https://vega.github.io/vega/docs/specification/) in addition to several specific to this API (such as `:style` and `:depth`). 
+* `:leaves`: `Mark` specifying the leaf nodes of the tree + * `:text`: [Text Mark](https://vega.github.io/vega/docs/marks/text/) + * `:rect`: [Rect Mark](https://vega.github.io/vega/docs/marks/rect/) +* `:splits`: `Mark` specifying the split (or inner / decision) nodes of the tree + * `:text`: [Text Mark](https://vega.github.io/vega/docs/marks/text/) + * `:rect`: [Rect Mark](https://vega.github.io/vega/docs/marks/rect/) + * `:children`: [Text Mark](https://vega.github.io/vega/docs/marks/text/) for the child count +* `:yes` + * `:path`: [Path Mark](https://vega.github.io/vega/docs/marks/path/) + * `:text`: [Text Mark](https://vega.github.io/vega/docs/marks/text/) +* `:no` + * `:path`: [Path Mark](https://vega.github.io/vega/docs/marks/path/) + * `:text`: [Text Mark](https://vega.github.io/vega/docs/marks/text/) + +`EXGBoost.plot_tree/2` defaults to outputting a `VegaLite` struct. If you pass the `:path` option it will save to a file instead. + +If you want to add any marks to the underlying plot you will have to use the lower-level `EXGBoost.Plotting` API, as the top-level API is only capable of customizing these marks. + + + +### Top-Level Keys + + + +`EXGBoost` supports changing the direction of the plots through the `:rankdir` option. Available directions are `[:tb, :bt, :lr, :rl]`, with top-to-bottom (`:tb`) being the default. + +```elixir +EXGBoost.plot_tree(booster, rankdir: :bt) +``` + +By default, plotting only shows one (the first) tree, but seeing as a `Booster` is really an ensemble of trees you can choose which tree to plot through the `:index` option, or set it to `nil` to have a dropdown box to select the tree. + +```elixir +EXGBoost.plot_tree(booster, rankdir: :lr, index: nil) +``` + +You'll also notice that the plot is interactive, with support for scrolling, zooming, and collapsing sections of the tree. If you click on a split node you will toggle the visibility of its descendants, and the rest of the tree will fill the canvas. 
+ +You can also use the `:depth` option to programmatically set the max depth to display in the tree: + +```elixir +EXGBoost.plot_tree(booster, rankdir: :lr, index: 4, depth: 3) +``` + +One way to affect the canvas size is by controlling the padding. + +You can add padding to all sides by specifying an integer for the `:padding` option: + +```elixir +EXGBoost.plot_tree(booster, rankdir: :rl, index: 4, depth: 3, padding: 50) +``` + +Or specify padding for each side: + +```elixir +EXGBoost.plot_tree(booster, + rankdir: :lr, + index: 4, + depth: 3, + padding: [top: 5, bottom: 25, left: 50, right: 10] +) +``` + +You can also specify the canvas size using the `:width` and `:height` options: + +```elixir +EXGBoost.plot_tree(booster, + rankdir: :lr, + index: 4, + depth: 3, + width: 500, + height: 500 +) +``` + +But do note that changing the padding of a canvas does change the size, even if you specify the size using `:height` and `:width`: + +```elixir +EXGBoost.plot_tree(booster, + rankdir: :lr, + index: 4, + depth: 3, + width: 500, + height: 500, + padding: 10 +) +``` + +You can change the dimensions of all nodes through the `:node_height` and `:node_width` options: + +```elixir +EXGBoost.plot_tree(booster, rankdir: :lr, index: 4, depth: 3, node_width: 60, node_height: 60) +``` + +Or change the space between nodes using the `:space_between` option. 
+ +**Note that the size of the accompanying nodes and/or text will change to accommodate the new `:space_between` option while trying to maintain the canvas size.** + +```elixir +EXGBoost.plot_tree( + booster, + rankdir: :lr, + index: 4, + depth: 3, + space_between: [nodes: 200] +) +``` + +So if you want to add the space between while not changing the size of the nodes you might need to manually adjust the canvas size: + +```elixir +EXGBoost.plot_tree( + booster, + rankdir: :lr, + index: 4, + depth: 3, + space_between: [nodes: 200], + height: 800 +) +``` + +```elixir +EXGBoost.plot_tree( + booster, + rankdir: :lr, + index: 4, + depth: 3, + space_between: [levels: 200] +) +``` + +### Mark Options + +The options controlling the appearance of individual marks all conform to a similar API. You can refer to the options and pre-defined defaults for a subset of the allowed options, but you can also pass other options so long as they are allowed by the Vega Mark spec (as defined [here](#cell-y5oxrrri4daa6xt5)). + +```elixir +EXGBoost.plot_tree( + booster, + rankdir: :bt, + index: 4, + depth: 3, + space_between: [levels: 200], + yes: [ + text: [font_size: 18, fill: :teal] + ], + no: [ + text: [font_size: 20] + ], + node_width: 100 +) +``` + +Most marks accept an `:opacity` option that you can use to effectively hide the mark: + +```elixir +EXGBoost.plot_tree( + booster, + rankdir: :lr, + index: 4, + depth: 3, + splits: [ + text: [opacity: 0], + rect: [opacity: 0], + children: [opacity: 1] + ] +) +``` + +And `text` marks accept normal text options such as `:fill`, `:font_size`, and `:font`: + +```elixir +EXGBoost.plot_tree( + booster, + node_width: 250, + splits: [ + text: [font: "Helvetica Neue", font_size: 20, fill: "orange"] + ], + space_between: [levels: 20] +) +``` + +### Styles + +There is a set of provided pre-configured settings for the top-level API that you may optionally use. 
You can refer to the `EXGBoost.Plotting.Styles` docs to see a gallery of each style in action. You can specify a style with the `:style` option in `EXGBoost.plot_tree/2`. + +You can still specify custom settings along with using a style. Most styles only specify a subset of the total possible settings, but you are free to specify any other allowed keys and they will be merged with the style. Any option passed explicitly to the function **does** take precedence over the style options. + +For example, let's look at the `:solarized_dark` style: + +```elixir +EXGBoost.Plotting.solarized_dark() |> Keyword.take([:background, :height]) |> IO.inspect() +EXGBoost.plot_tree(booster, style: :solarized_dark) +``` + +You can see that it defines a background color of `#002b36` but does not restrict what the height must be. + +```elixir +EXGBoost.plot_tree(booster, style: :solarized_dark, background: "white", height: 200) +``` + +We specified both `:background` and `:height` here, and the background specified in the option supersedes the one from the style. + +You can also always get the style specification as a `Keyword` which can be passed to `EXGBoost.plot_tree/2` manually, making any needed changes yourself, like so: + +```elixir +custom_style = EXGBoost.Plotting.solarized_dark() |> Keyword.put(:background, "white") +EXGBoost.plot_tree(booster, style: custom_style) +``` + +You can also programmatically check which styles are available: + +```elixir +EXGBoost.Plotting.get_styles() +``` + +### Configuration + +You can also set defaults for the top-level API using an `Application` configuration for `EXGBoost` under the `:plotting` key. Since the defaults are collected from your configuration file at compile-time, anything you set during runtime, even if you set it to the Application environment, will not be registered as defaults. 
+ +For example, if you just want to change the default pre-configured style you can do: + + + +```elixir +Mix.install([ + {:exgboost, path: Path.join(__DIR__, ".."), env: :dev}, +], + config: + [ + exgboost: [ + plotting: [ + style: :solarized_dark, + ]] + ], + lockfile: :exgboost) +``` + +You can also make one-off changes to any of the settings with this method. In effect, this turns into a default custom style. **Just make sure to set `style: nil` to ensure that the `style` option doesn't supersede any of your settings.** Here's an example of that: + + + +```elixir + default_style = + [ + style: nil, + background: "#3f3f3f", + leaves: [ + # Foreground + text: [fill: "#dcdccc", font_size: 12, font_style: "normal", font_weight: "normal"], + # Comment + rect: [fill: "#7f9f7f", stroke: "#7f9f7f"] + ], + splits: [ + # Foreground + text: [fill: "#dcdccc", font_size: 12, font_style: "normal", font_weight: "bold"], + # Comment + rect: [fill: "#7f9f7f", stroke: "#7f9f7f"], + # Selection + children: [fill: "#2b2b2b", stroke: "#2b2b2b"] + ], + yes: [ + # Green + text: [fill: "#7f9f7f"], + # Selection + path: [stroke: "#2b2b2b"] + ], + no: [ + # Red + text: [fill: "#cc9393"], + # Selection + path: [stroke: "#2b2b2b"] + ] + ] + +Mix.install([ + {:exgboost, path: Path.join(__DIR__, ".."), env: :dev}, +], +config: + [ + exgboost: [ + plotting: default_style, + ] + ] +) +``` + +**NOTE: When you specify a parameter in the configuration, it is merged with the defaults which is different from runtime behavior.** + +At any point, you can check what your default settings are by using `EXGBoost.Plotting.get_defaults/0` + +```elixir +EXGBoost.Plotting.get_defaults() +``` + +## Low-Level API + +If you find yourself needing more granular control over your plots, you can reach towards the `EXGBoost.Plotting` module. This module houses the `EXGBoost.Plotting.plot/2` function, which is what is used under the hood from the `EXGBoost.plot_tree/2` top-level API. 
This module also has the `get_data_spec/2` function, as well as the `to_tabular/1` function, both of which can be used to specify your own Vega specification. Lastly, the module also houses all of the pre-configured styles, which are 0-arity functions which output the `Keyword`s containing their respective style's options that can be passed to the plotting APIs. + +Let's briefly go over the `to_tabular/1` and `get_data_spec/2` functions: + + + +The `to_tabular/1` function is used to convert a `Booster`, which is formatted as a tree structure, to a tabular format which can be ingested specifically by the [Vega Stratify transform](https://vega.github.io/vega/docs/transforms/stratify/). It returns a list of "nodes", which are just `Map`s with info about each node in the tree. + +```elixir +EXGBoost.Plotting.to_tabular(booster) |> hd +``` + +You can use this function if you want to have complete control over the visualization, and just want a bit of a head start with respect to data transformation for converting the `Booster` into a more digestible format. + + + +The `get_data_spec/2` function is used if you want to use the provided [Vega data specification](https://vega.github.io/vega/docs/data/). This is for those who want to only focus on implementing their own [Vega Marks](https://vega.github.io/vega/docs/marks/), and want to leverage the data transformation pipeline that powers the top-level API. 
+ +The data transformation used is the following pipeline: + +`to_tabular/1` -> [Filter](https://vega.github.io/vega/docs/transforms/filter/) (by tree index) -> [Stratify](https://vega.github.io/vega/docs/transforms/stratify/) -> [Tree](https://vega.github.io/vega/docs/transforms/tree/) + +```elixir +EXGBoost.Plotting.get_data_spec(booster, rankdir: :bt) +``` + +The Vega fields which are not included with `get_data_spec/2` and are included in `plot/2` are: + +* [Marks](https://vega.github.io/vega/docs/marks/) +* [Scales](https://vega.github.io/vega/docs/scales/) +* [Signals](https://vega.github.io/vega/docs/signals/) + +You can make a completely valid plot using only the Data from `get_data_spec/2` and adding the marks you need. diff --git a/test/data/model.json b/test/data/model.json new file mode 100644 index 0000000..6f3a378 Binary files /dev/null and b/test/data/model.json differ diff --git a/test/exgboost_test.exs b/test/exgboost_test.exs index 60ffadb..cf96554 100644 --- a/test/exgboost_test.exs +++ b/test/exgboost_test.exs @@ -151,6 +151,24 @@ defmodule EXGBoostTest do refute is_nil(booster.best_iteration) refute is_nil(booster.best_score) + + # If no eval metric is provided, the default metric is used. If the default + # metric is disabled, an error is raised. + assert_raise ArgumentError, + fn -> + ExUnit.CaptureIO.with_io(fn -> + EXGBoost.train(x, y, + disable_default_eval_metric: true, + num_boost_rounds: 10, + early_stopping_rounds: 1, + evals: [{x, y, "validation"}], + tree_method: :hist + ) + end) + end + + refute is_nil(booster.best_iteration) + refute is_nil(booster.best_score) end test "eval with multiple metrics", context do diff --git a/test/model.json b/test/model.json new file mode 100644 index 0000000..a3575a8 Binary files /dev/null and b/test/model.json differ