From 630ac241dbf3085a6d0168cf052df282252d810f Mon Sep 17 00:00:00 2001 From: acalejos Date: Sat, 2 Dec 2023 12:01:56 -0500 Subject: [PATCH 01/21] Plotting WIP --- lib/exgboost.ex | 18 +++ lib/exgboost/plotting.ex | 318 +++++++++++++++++++++++++++++++++++++++ mix.exs | 6 +- mix.lock | 23 +++ 4 files changed, 364 insertions(+), 1 deletion(-) create mode 100644 lib/exgboost/plotting.ex diff --git a/lib/exgboost.ex b/lib/exgboost.ex index 67f9636..d893d2b 100644 --- a/lib/exgboost.ex +++ b/lib/exgboost.ex @@ -121,6 +121,7 @@ defmodule EXGBoost do alias EXGBoost.DMatrix alias EXGBoost.ProxyDMatrix alias EXGBoost.Training + alias EXGBoost.Plotting @doc """ Check the build information of the xgboost library. @@ -549,4 +550,21 @@ defmodule EXGBoost do def load_weights(buffer) do EXGBoost.Booster.load(buffer, deserialize: :weights, from: :buffer) end + + @doc """ + Plot a tree from a Booster model and save it to a file. + + ## Options + * `:format` - the format to export the graphic as, must be either of: `:json`, `:html`, `:png`, `:svg`, `:pdf`. By default the format is inferred from the file extension. + * `:local_npm_prefix` - a relative path pointing to a local npm project directory where the necessary npm packages are installed. For instance, in Phoenix projects you may want to pass local_npm_prefix: "assets". By default the npm packages are searched for in the current directory and globally. + """ + def plot_tree(booster, path, opts \\ []) when is_binary(path) do + vega = Plotting.to_vega(booster) + + if Keyword.get(opts, :path, nil) == nil do + :ok = VegaLite.Export.save!(vega, path, opts) + end + + vega + end end diff --git a/lib/exgboost/plotting.ex b/lib/exgboost/plotting.ex new file mode 100644 index 0000000..9e19baf --- /dev/null +++ b/lib/exgboost/plotting.ex @@ -0,0 +1,318 @@ +defmodule EXGBoost.Plotting.Schema do + defmacro __before_compile__(_env) do + HTTPoison.start() + + schema = HTTPoison.get!("https://vega.github.io/schema/vega/v5.json").body + + quote do + require Exonerate + Exonerate.function_from_string(:def, :validate, unquote(schema), dump: true, metadata: true) + end + |> Code.compile_quoted() + end +end + +defmodule EXGBoost.Plotting do + @before_compile EXGBoost.Plotting.Schema + + @plotting_params [ + height: [ + description: "Height of the plot in pixels", + type: :pos_integer, + default: 400 + ], + width: [ + description: "Width of the plot in pixels", + type: :pos_integer, + default: 600 + ], + padding: [ + description: + "The padding in pixels to add around the visualization. If a number, specifies padding for all sides. If an object, the value should have the format `[left: value, right: value, top: value, bottom: value]`", + type: + {:or, + [ + :pos_integer, + keyword_list: [ + left: [type: :pos_integer], + right: [type: :pos_integer], + top: [type: :pos_integer], + bottom: [type: :pos_integer] + ] + ]}, + default: 5 + ] + ] + + defp _to_tabular(idx, node, parent) do + node = Map.put(node, "tree_id", idx) + node = if parent, do: Map.put(node, "parentid", parent), else: node + node = new_node(node) + + if Map.has_key?(node, "children") do + [ + Map.delete(node, "children") + | Enum.flat_map(Map.get(node, "children"), fn child -> + to_tabular(idx, child, Map.get(node, "nodeid")) + end) + ] + else + [node] + end + end + + defp new_node(params = %{}) do + Map.merge( + %{ + "nodeid" => nil, + "depth" => nil, + "missing" => nil, + "no" => nil, + "leaf" => nil, + "split" => nil, + "split_condition" => nil, + "yes" => nil, + "tree_id" => nil, + "parentid" => nil + }, + params + ) + end + + @doc """ + Outputs details of the tree in a tabular format which can be consumed + by plotting libraries such as Vega. Outputs as a list of maps, where + each map represents a node in the tree. + + Table columns: + - tree_id: The tree id + - nodeid: The node id + - parentid: The parent node id + - split: The split feature + - split_condition: The split condition + - yes: The node id of the left child + - no: The node id of the right child + - missing: The node id of the missing child + - depth: The depth of the node + - leaf: The leaf value if it is a leaf node + """ + def to_tabular(booster) do + booster + |> EXGBoost.Booster.get_dump(format: :json) + |> Enum.map(&Jason.decode!(&1)) + |> Enum.with_index() + |> Enum.reduce( + [], + fn {tree, index}, acc -> + acc ++ _to_tabular(index, tree, nil) + end + ) + end + + defp to_vega(booster) do + tabular = + booster + |> EXGBoost.Booster.get_dump(format: :json) + |> Enum.map(&Jason.decode!(&1)) + |> Enum.with_index() + |> Enum.reduce( + [], + fn {tree, index}, acc -> + acc ++ to_tabular(index, tree, nil) + end + ) + + spec = %{ + "$schema" => "https://vega.github.io/schema/vega/v5.json", + "description" => "A basic decision tree layout of nodes", + "width" => 800, + "height" => 600, + "padding" => 5, + "signals" => [ + %{ + "name" => "treeSelector", + "value" => 0, + "bind" => %{ + "input" => "select", + "options" => + tabular + |> Enum.reduce([], fn node, acc -> + id = Map.get(node, "tree_id") + if id in acc, do: acc, else: [id | acc] + end) + |> Enum.sort() + } + } + ], + "data" => [ + %{ + "name" => "tree", + "values" => tabular, + "transform" => [ + %{ + "type" => "filter", + "expr" => "datum.tree_id === treeSelector" + }, + %{ + "type" => "stratify", + "key" => "nodeid", + "parentKey" => "parentid" + }, + %{ + "type" => "tree", + "size" => [%{"signal" => "width"}, %{"signal" => "height"}], + "as" => ["x", "y", "depth", "children"] + } + ] + }, + %{ + "name" => "links", + "source" => "tree", + "transform" => [ + %{"type" => "treelinks"}, + %{ + "type" => "linkpath", + "shape" => "line" + } + ] + }, + %{ + "name" => "innerNodes", + "source" => "tree", + "transform" => [ + %{"type" => "filter", "expr" => "datum.leaf == null"} + ] + }, + %{ + "name" => "leafNodes", + "source" => "tree", + "transform" => [ + %{"type" => "filter", "expr" => "datum.leaf != null"} + ] + } + ], + "scales" => [ + %{ + "name" => "color", + "type" => "ordinal", + "domain" => %{"data" => "tree", "field" => "depth"}, + "range" => %{"scheme" => "category10"} + } + ], + "marks" => [ + %{ + "type" => "path", + "from" => %{"data" => "links"}, + "encode" => %{ + "update" => %{ + "path" => %{"field" => "path"}, + "stroke" => %{"value" => "#ccc"} + } + } + }, + %{ + "type" => "text", + "from" => %{"data" => "tree"}, + "encode" => %{ + "enter" => %{ + "text" => %{ + "signal" => + "datum.split ? datum.split + ' <= ' + format(datum.split_condition, '.2f') : ''" + }, + "fontSize" => %{"value" => 10}, + "baseline" => %{"value" => "middle"}, + "align" => %{"value" => "center"}, + "dx" => %{"value" => 0}, + "dy" => %{"value" => 0} + }, + "update" => %{ + "x" => %{"field" => "x"}, + "y" => %{"field" => "y"}, + "fill" => %{"value" => "black"} + } + } + }, + %{ + "type" => "symbol", + "from" => %{"data" => "innerNodes"}, + "encode" => %{ + "enter" => %{ + "fill" => %{"value" => "none"}, + "shape" => %{"value" => "circle"}, + "x" => %{"field" => "x"}, + "y" => %{"field" => "y"}, + "size" => %{"value" => 800} + }, + "update" => %{ + "fill" => %{"scale" => "color", "field" => "depth"}, + "opacity" => %{"value" => 1}, + "stroke" => %{"value" => "black"}, + "strokeWidth" => %{"value" => 1}, + "fillOpacity" => %{"value" => 0}, + "strokeOpacity" => %{"value" => 1} + } + } + }, + %{ + "type" => "rect", + "from" => %{"data" => "leafNodes"}, + "encode" => %{ + "enter" => %{ + "x" => %{"signal" => "datum.x - 20"}, + "y" => %{"signal" => "datum.y - 10"}, + "width" => %{"value" => 40}, + "height" => %{"value" => 20}, + "stroke" => %{"value" => "black"}, + "fill" => %{"value" => "transparent"} + } + } + }, + %{ + "type" => "text", + "from" => %{"data" => "tree"}, + "encode" => %{ + "enter" => %{ + "text" => %{ + "signal" => "format(datum.leaf, '.2f')" + }, + "fontSize" => %{"value" => 10}, + "baseline" => %{"value" => "middle"}, + "align" => %{"value" => "center"}, + "dx" => %{"value" => 0}, + "dy" => %{"value" => 0} + }, + "update" => %{ + "x" => %{"field" => "x"}, + "y" => %{"field" => "y"}, + "fill" => %{"value" => "black"}, + "opacity" => %{"signal" => "datum.leaf != null ? 1 : 0"} + } + } + }, + %{ + "type" => "text", + "from" => %{"data" => "links"}, + "encode" => %{ + "enter" => %{ + "x" => %{"signal" => "(datum.source.x + datum.target.x) / 2"}, + "y" => %{"signal" => "(datum.source.y + datum.target.y) / 2"}, + "text" => %{ + "signal" => "datum.source.yes === datum.target.nodeid ? 'yes' : 'no'" + }, + "align" => %{"value" => "center"}, + "baseline" => %{"value" => "middle"} + } + } + } + ] + } + + Jason.encode!(spec) |> VegaLite.from_json() + end +end + +defimpl Kino.Render, for: EXGBoost.Booster do + def to_livebook(booster) do + Kino.Render.to_livebook(EXGBoost.Plotting.to_vega(booster)) + end +end diff --git a/mix.exs b/mix.exs index 9e33686..d31685f 100644 --- a/mix.exs +++ b/mix.exs @@ -48,7 +48,11 @@ defmodule EXGBoost.MixProject do {:jason, "~> 1.3"}, {:ex_doc, "~> 0.29.0", only: :docs}, {:cc_precompiler, "~> 0.1.0", runtime: false}, - {:exterval, "0.1.0"} + {:exterval, "0.1.0"}, + {:exonerate, "~> 1.1", runtime: false}, + {:httpoison, "~> 2.0", runtime: false}, + {:vega_lite, "~> 0.1"}, + {:kino, "~> 0.11"} ] end diff --git a/mix.lock b/mix.lock index 37368c9..c313017 100644 --- a/mix.lock +++ b/mix.lock @@ -1,16 +1,39 @@ %{ + "castore": {:hex, :castore, "1.0.4", "ff4d0fb2e6411c0479b1d965a814ea6d00e51eb2f58697446e9c41a97d940b28", [:mix], [], "hexpm", "9418c1b8144e11656f0be99943db4caf04612e3eaecefb5dae9a2a87565584f8"}, "cc_precompiler": {:hex, :cc_precompiler, "0.1.7", "77de20ac77f0e53f20ca82c563520af0237c301a1ec3ab3bc598e8a96c7ee5d9", [:mix], [{:elixir_make, "~> 0.7.3", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "2768b28bf3c2b4f788c995576b39b8cb5d47eb788526d93bd52206c1d8bf4b75"}, + "certifi": {:hex, :certifi, "2.12.0", "2d1cca2ec95f59643862af91f001478c9863c2ac9cb6e2f89780bfd8de987329", [:rebar3], [], "hexpm", "ee68d85df22e554040cdb4be100f33873ac6051387baf6a8f6ce82272340ff1c"}, "complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"}, "earmark_parser": {:hex, :earmark_parser, "1.4.32", "fa739a0ecfa34493de19426681b23f6814573faee95dfd4b4aafe15a7b5b32c6", [:mix], [], "hexpm", "b8b0dd77d60373e77a3d7e8afa598f325e49e8663a51bcc2b88ef41838cca755"}, "elixir_make": {:hex, :elixir_make, "0.7.7", "7128c60c2476019ed978210c245badf08b03dbec4f24d05790ef791da11aa17c", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}], "hexpm", "5bc19fff950fad52bbe5f211b12db9ec82c6b34a9647da0c2224b8b8464c7e6c"}, "ex_doc": {:hex, :ex_doc, "0.29.4", "6257ecbb20c7396b1fe5accd55b7b0d23f44b6aa18017b415cb4c2b91d997729", [:mix], [{:earmark_parser, "~> 1.4.31", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_elixir, "~> 0.14", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1", [hex: :makeup_erlang, repo: "hexpm", optional: false]}], "hexpm", "2c6699a737ae46cb61e4ed012af931b57b699643b24dabe2400a8168414bc4f5"}, + "exonerate": {:hex, :exonerate, "1.1.3", "96c3249f0358567151d595a49d612849a7a57930575a1eb781c178f2e84b1559", [:mix], [{:finch, "~> 0.15", [hex: :finch, repo: "hexpm", optional: true]}, {:idna, "~> 6.1.1", [hex: :idna, repo: "hexpm", optional: true]}, {:jason, "~> 1.4.0", [hex: :jason, repo: "hexpm", optional: false]}, {:json_ptr, "~> 1.0", [hex: :json_ptr, repo: "hexpm", optional: false]}, {:match_spec, "~> 0.3.1", [hex: :match_spec, repo: "hexpm", optional: false]}, {:pegasus, "~> 0.2.2", [hex: :pegasus, repo: "hexpm", optional: true]}, {:req, "~> 0.3", [hex: :req, repo: "hexpm", optional: true]}, {:yaml_elixir, "~> 2.7", [hex: :yaml_elixir, repo: "hexpm", optional: true]}], "hexpm", "cb750e303ae79987153eecb474ccd011f45e3dea7c4f988d8286facdf56db1d2"}, "exterval": {:hex, :exterval, "0.1.0", "d35ec43b0f260239f859665137fac0974f1c6a8d50bf8d52b5999c87c67c63e5", [:mix], [], "hexpm", "8ed444558a501deec6563230e3124cdf242c413a95eb9ca9f39de024ad779d7f"}, + "finch": {:hex, :finch, "0.16.0", "40733f02c89f94a112518071c0a91fe86069560f5dbdb39f9150042f44dcfb1a", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:mime, "~> 1.0 or ~> 2.0", [hex: :mime, repo: "hexpm", optional: false]}, {:mint, "~> 1.3", [hex: :mint, repo: "hexpm", optional: false]}, {:nimble_options, "~> 0.4 or ~> 1.0", [hex: :nimble_options, repo: "hexpm", optional: false]}, {:nimble_pool, "~> 0.2.6 or ~> 1.0", [hex: :nimble_pool, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "f660174c4d519e5fec629016054d60edd822cdfe2b7270836739ac2f97735ec5"}, + "fss": {:hex, :fss, "0.1.1", "9db2344dbbb5d555ce442ac7c2f82dd975b605b50d169314a20f08ed21e08642", [:mix], [], "hexpm", "78ad5955c7919c3764065b21144913df7515d52e228c09427a004afe9c1a16b0"}, + "hackney": {:hex, :hackney, "1.20.1", "8d97aec62ddddd757d128bfd1df6c5861093419f8f7a4223823537bad5d064e2", [:rebar3], [{:certifi, "~> 2.12.0", [hex: :certifi, repo: "hexpm", optional: false]}, {:idna, "~> 6.1.0", [hex: :idna, repo: "hexpm", optional: false]}, {:metrics, "~> 1.0.0", [hex: :metrics, repo: "hexpm", optional: false]}, {:mimerl, "~> 1.1", [hex: :mimerl, repo: "hexpm", optional: false]}, {:parse_trans, "3.4.1", [hex: :parse_trans, repo: "hexpm", optional: false]}, {:ssl_verify_fun, "~> 1.1.0", [hex: :ssl_verify_fun, repo: "hexpm", optional: false]}, {:unicode_util_compat, "~> 0.7.0", [hex: :unicode_util_compat, repo: "hexpm", optional: false]}], "hexpm", "fe9094e5f1a2a2c0a7d10918fee36bfec0ec2a979994cff8cfe8058cd9af38e3"}, + "hpax": {:hex, :hpax, "0.1.2", "09a75600d9d8bbd064cdd741f21fc06fc1f4cf3d0fcc335e5aa19be1a7235c84", [:mix], [], "hexpm", "2c87843d5a23f5f16748ebe77969880e29809580efdaccd615cd3bed628a8c13"}, + "httpoison": {:hex, :httpoison, "2.2.1", "87b7ed6d95db0389f7df02779644171d7319d319178f6680438167d7b69b1f3d", [:mix], [{:hackney, "~> 1.17", [hex: :hackney, repo: "hexpm", optional: false]}], "hexpm", "51364e6d2f429d80e14fe4b5f8e39719cacd03eb3f9a9286e61e216feac2d2df"}, + "idna": {:hex, :idna, "6.1.1", "8a63070e9f7d0c62eb9d9fcb360a7de382448200fbbd1b106cc96d3d8099df8d", [:rebar3], [{:unicode_util_compat, "~> 0.7.0", [hex: :unicode_util_compat, repo: "hexpm", optional: false]}], "hexpm", "92376eb7894412ed19ac475e4a86f7b413c1b9fbb5bd16dccd57934157944cea"}, "jason": {:hex, :jason, "1.4.0", "e855647bc964a44e2f67df589ccf49105ae039d4179db7f6271dfd3843dc27e6", [:mix], [{:decimal, "~> 1.0 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: true]}], "hexpm", "79a3791085b2a0f743ca04cec0f7be26443738779d09302e01318f97bdb82121"}, + "json_ptr": {:hex, :json_ptr, "1.2.0", "07a757d1b0a86b7fd73f5a5b26d4d41c5bc7d5be4a3e3511d7458293ce70b8bb", [:mix], [{:jason, "~> 1.4.0", [hex: :jason, repo: "hexpm", optional: false]}], "hexpm", "e58704ac304cbf3832c0ac161e76479e7b05f75427991ddd57e19b307ae4aa05"}, + "kino": {:hex, :kino, "0.11.3", "c48a3f8bcad5ec103884aee119dc01594c0ead7c755061853147357c01885f2c", [:mix], [{:fss, "~> 0.1.0", [hex: :fss, repo: "hexpm", optional: false]}, {:nx, "~> 0.1", [hex: :nx, repo: "hexpm", optional: true]}, {:table, "~> 0.1.2", [hex: :table, repo: "hexpm", optional: false]}], "hexpm", "027e1f634ce175e9092d9bfc632c911ca50ad735eecc77770bfaf4b4a8b2932f"}, "makeup": {:hex, :makeup, "1.1.0", "6b67c8bc2882a6b6a445859952a602afc1a41c2e08379ca057c0f525366fc3ca", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "0a45ed501f4a8897f580eabf99a2e5234ea3e75a4373c8a52824f6e873be57a6"}, "makeup_elixir": {:hex, :makeup_elixir, "0.16.1", "cc9e3ca312f1cfeccc572b37a09980287e243648108384b97ff2b76e505c3555", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "e127a341ad1b209bd80f7bd1620a15693a9908ed780c3b763bccf7d200c767c6"}, "makeup_erlang": {:hex, :makeup_erlang, "0.1.2", "ad87296a092a46e03b7e9b0be7631ddcf64c790fa68a9ef5323b6cbb36affc72", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "f3f5a1ca93ce6e092d92b6d9c049bcda58a3b617a8d888f8e7231c85630e8108"}, + "match_spec": {:hex, :match_spec, "0.3.1", "9fb2b9313dbd2c5aa6210a8c11eea94e0531a0982e463d9467445e8615081af0", [:mix], [], "hexpm", "225e90aa02c8023b12cf47e33c8a7307ac88b31f461b43d93e040f4641704557"}, + "metrics": {:hex, :metrics, "1.0.1", "25f094dea2cda98213cecc3aeff09e940299d950904393b2a29d191c346a8486", [:rebar3], [], "hexpm", "69b09adddc4f74a40716ae54d140f93beb0fb8978d8636eaded0c31b6f099f16"}, + "mime": {:hex, :mime, "2.0.5", "dc34c8efd439abe6ae0343edbb8556f4d63f178594894720607772a041b04b02", [:mix], [], "hexpm", "da0d64a365c45bc9935cc5c8a7fc5e49a0e0f9932a761c55d6c52b142780a05c"}, + "mimerl": {:hex, :mimerl, "1.2.0", "67e2d3f571088d5cfd3e550c383094b47159f3eee8ffa08e64106cdf5e981be3", [:rebar3], [], "hexpm", "f278585650aa581986264638ebf698f8bb19df297f66ad91b18910dfc6e19323"}, + "mint": {:hex, :mint, "1.5.1", "8db5239e56738552d85af398798c80648db0e90f343c8469f6c6d8898944fb6f", [:mix], [{:castore, "~> 0.1.0 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}, {:hpax, "~> 0.1.1", [hex: :hpax, repo: "hexpm", optional: false]}], "hexpm", "4a63e1e76a7c3956abd2c72f370a0d0aecddc3976dea5c27eccbecfa5e7d5b1e"}, "nimble_options": {:hex, :nimble_options, "1.0.2", "92098a74df0072ff37d0c12ace58574d26880e522c22801437151a159392270e", [:mix], [], "hexpm", "fd12a8db2021036ce12a309f26f564ec367373265b53e25403f0ee697380f1b8"}, "nimble_parsec": {:hex, :nimble_parsec, "1.3.1", "2c54013ecf170e249e9291ed0a62e5832f70a476c61da16f6aac6dca0189f2af", [:mix], [], "hexpm", "2682e3c0b2eb58d90c6375fc0cc30bc7be06f365bf72608804fb9cffa5e1b167"}, + "nimble_pool": {:hex, :nimble_pool, "1.0.0", "5eb82705d138f4dd4423f69ceb19ac667b3b492ae570c9f5c900bb3d2f50a847", [:mix], [], "hexpm", "80be3b882d2d351882256087078e1b1952a28bf98d0a287be87e4a24a710b67a"}, "nx": {:hex, :nx, "0.5.3", "6ad5534f9b82429dafa12329952708c2fdd6ab01b306e86333fdea72383147ee", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "d1072fc4423809ed09beb729e73c200ce177ddecac425d9eb6fba643669623ec"}, + "parse_trans": {:hex, :parse_trans, "3.4.1", "6e6aa8167cb44cc8f39441d05193be6e6f4e7c2946cb2759f015f8c56b76e5ff", [:rebar3], [], "hexpm", "620a406ce75dada827b82e453c19cf06776be266f5a67cff34e1ef2cbb60e49a"}, + "req": {:hex, :req, "0.4.5", "2071bbedd280f107b9e33e1ddff2beb3991ec1ae06caa2cca2ab756393d8aca5", [:mix], [{:brotli, "~> 0.3.1", [hex: :brotli, repo: "hexpm", optional: true]}, {:ezstd, "~> 1.0", [hex: :ezstd, repo: "hexpm", optional: true]}, {:finch, "~> 0.9", [hex: :finch, repo: "hexpm", optional: false]}, {:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: false]}, {:mime, "~> 1.6 or ~> 2.0", [hex: :mime, repo: "hexpm", optional: false]}, {:nimble_csv, "~> 1.0", [hex: :nimble_csv, repo: "hexpm", optional: true]}, {:plug, "~> 1.0", [hex: :plug, repo: "hexpm", optional: true]}], "hexpm", "dd23e9c7303ddeb2dee09ff11ad8102cca019e38394456f265fb7b9655c64dd8"}, + "ssl_verify_fun": {:hex, :ssl_verify_fun, "1.1.7", "354c321cf377240c7b8716899e182ce4890c5938111a1296add3ec74cf1715df", [:make, :mix, :rebar3], [], "hexpm", "fe4c190e8f37401d30167c8c405eda19469f34577987c76dde613e838bbc67f8"}, + "table": {:hex, :table, "0.1.2", "87ad1125f5b70c5dea0307aa633194083eb5182ec537efc94e96af08937e14a8", [:mix], [], "hexpm", "7e99bc7efef806315c7e65640724bf165c3061cdc5d854060f74468367065029"}, "telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"}, + "unicode_util_compat": {:hex, :unicode_util_compat, "0.7.0", "bc84380c9ab48177092f43ac89e4dfa2c6d62b40b8bd132b1059ecc7232f9a78", [:rebar3], [], "hexpm", "25eee6d67df61960cf6a794239566599b09e17e668d3700247bc498638152521"}, + "vega_lite": {:hex, :vega_lite, "0.1.8", "7f6119126ecaf4bc2c1854084370d7091424f5cce4795fbac044eee9963f0752", [:mix], [{:table, "~> 0.1.0", [hex: :table, repo: "hexpm", optional: false]}], "hexpm", "6c8a9271f850612dd8a90de8d1ebd433590ed07ffef76fc2397c240dc04d3fdc"}, } From d787808eed1ea50e6187afef2c3bf67212597ed6 Mon Sep 17 00:00:00 2001 From: acalejos Date: Sun, 3 Dec 2023 12:45:57 -0500 Subject: [PATCH 02/21] use ex json schema --- lib/exgboost/plotting.ex | 489 ++++++++++++++++++++++++++------------- mix.exs | 2 +- mix.lock | 2 + 3 files changed, 331 insertions(+), 162 deletions(-) diff --git a/lib/exgboost/plotting.ex b/lib/exgboost/plotting.ex index 9e19baf..dc1a972 100644 --- a/lib/exgboost/plotting.ex +++ b/lib/exgboost/plotting.ex @@ -1,33 +1,99 @@ -defmodule EXGBoost.Plotting.Schema do - defmacro __before_compile__(_env) do - HTTPoison.start() +# defmodule EXGBoost.Plotting.Schema do +# @moduledoc false +# Code.ensure_compiled!(ExJsonSchema) - schema = HTTPoison.get!("https://vega.github.io/schema/vega/v5.json").body +# defmacro __before_compile__(_env) do +# HTTPoison.start() - quote do - require Exonerate - Exonerate.function_from_string(:def, :validate, unquote(schema), dump: true, metadata: true) - end - |> Code.compile_quoted() - end -end +# schema = +# HTTPoison.get!("https://vega.github.io/schema/vega/v5.json").body +# |> Jason.decode!() +# |> ExJsonSchema.Schema.resolve() +# |> Macro.escape() + +# quote do +# def get_schema(), do: unquote(schema) +# end +# end +# end defmodule EXGBoost.Plotting do - @before_compile EXGBoost.Plotting.Schema + @moduledoc """ + Functions for plotting EXGBoost `Booster` models using [Vega](https://vega.github.io/vega/) + + Fundamentally, all this module does is convert a `Booster` into a format that can be + ingested by Vega, and apply some default configuations that only account for a subset of the configurations + that can be set by a Vega spec directly. The functions provided in this module are designed to have opinionated + defaults that can be used to quickly visualize a model, but the full power of Vega is available by using the + `to_tabular/1` function to convert the model into a tabular format, and then using the `to_vega/2` function + to convert the tabular format into a Vega specification. + + ## Default Vega Specification + + The default Vega specification is designed to be a good starting point for visualizing a model, but it is + possible to customize the specification by passing in a map of Vega properties to the `to_vega/2` function. + Refer to `Custom Vega Specifications` for more details on how to do this. + + By default, the Vega specification includes the following entities to use for rendering the model: + * `:width` - The width of the plot in pixels + * `:height` - The height of the plot in pixels + * `:padding` - The padding in pixels to add around the visualization. If a number, specifies padding for all sides. If an object, the value should have the format `[left: value, right: value, top: value, bottom: value]` + * `:leafs` - Specifies characteristics of leaf nodes + * `:inner_nodes` - Specifies characteristics of inner nodes + * `:links` - Specifies characteristics of links between nodes + + ## Custom Vega Specifications + + The default Vega specification is designed to be a good starting point for visualizing a model, but it is + possible to customize the specification by passing in a map of Vega properties to the `to_vega/2` function. + You can find the full list of Vega properties [here](https://vega.github.io/vega/docs/specification/). + + It is suggested that you use the data attributes provided by the default specification as a starting point, since they + provide the necessary data transformation to convert the model into a tree structure that can be visualized by Vega. + If you would like to customize the default specification, you can use `EXGBoost.Plotting.to_vega/1` to get the default + specification, and then modify it as needed. + + Once you have a custom specification, you can pass it to `VegaLite.from_json/1` to create a new `VegaLite` struct, after which + you can use the functions provided by the `VegaLite` module to render the model. + + ## Specification Validation + You can optionally validate your specification against the Vega schema by passing the `validate: true` option to `to_vega/2`. + This will raise an error if the specification is invalid. This is useful if you are creating a custom specification and want + to ensure that it is valid. Note that this will only validate the specification against the Vega schema, and not against the + VegaLite schema. This requires the [`ex_json_schema`] package to be installed. + + ## Livebook Integration + + This module also provides a `Kino.Render` implementation for `EXGBoost.Booster` which allows + models to be rendered directly in Livebook. This is done by converting the model into a Vega specification + and then using the `Kino.Render` implementation for Elixir's [`VegaLite`](https://hexdocs.pm/vega_lite/VegaLite.html) API + to render the model. + + + . The Vega specification is then passed to [VegaLite](https://hexdocs.pm/vega_lite/readme.html) + + + """ + + HTTPoison.start() + + @schema HTTPoison.get!("https://vega.github.io/schema/vega/v5.json").body + |> Jason.decode!() + |> ExJsonSchema.Schema.resolve() @plotting_params [ height: [ - description: "Height of the plot in pixels", + doc: "Height of the plot in pixels", type: :pos_integer, default: 400 ], width: [ - description: "Width of the plot in pixels", + doc: "Width of the plot in pixels", type: :pos_integer, default: 600 ], padding: [ - description: + doc: "The padding in pixels to add around the visualization. If a number, specifies padding for all sides. If an object, the value should have the format `[left: value, right: value, top: value, bottom: value]`", type: {:or, @@ -41,9 +107,59 @@ defmodule EXGBoost.Plotting do ] ]}, default: 5 + ], + leafs: [ + doc: "Specifies characteristics of leaf nodes", + type: :keyword_list, + default: [fill: "#ccc", stroke: "black", strokeWidth: 1] + ], + inner_nodes: [ + doc: "Specifies characteristics of inner nodes", + type: :boolean, + default: true + ], + links: [ + doc: "Specifies characteristics of links between nodes", + type: :keyword_list, + default: [stroke: "#ccc", strokeWidth: 1] + ], + validate: [ + doc: "Whether to validate the Vega specification against the Vega schema", + type: :boolean, + default: true + ], + index: [ + doc: + "The zero-indexed index of the tree to plot. If `nil`, plots all trees using a dropdown selector to switch between trees", + type: {:or, [nil, :pos_integer]}, + default: 0 ] ] + @plotting_schema NimbleOptions.new!(@plotting_params) + + def get_schema(), do: @schema + + defp validate_spec(spec) do + unless Code.ensure_loaded?(EXJsonSchema) do + raise( + RuntimeError, + "The `ex_json_schema` package must be installed to validate Vega specifications. Please install it by running `mix deps.get`" + ) + end + + case ExJsonSchema.Validator.validate(get_schema(), spec) do + :ok -> + spec + + {:error, errors} -> + raise( + ArgumentError, + "Invalid Vega specification: #{inspect(errors)}" + ) + end + end + defp _to_tabular(idx, node, parent) do node = Map.put(node, "tree_id", idx) node = if parent, do: Map.put(node, "parentid", parent), else: node @@ -53,7 +169,7 @@ defmodule EXGBoost.Plotting do [ Map.delete(node, "children") | Enum.flat_map(Map.get(node, "children"), fn child -> - to_tabular(idx, child, Map.get(node, "nodeid")) + _to_tabular(idx, child, Map.get(node, "nodeid")) end) ] else @@ -109,50 +225,20 @@ defmodule EXGBoost.Plotting do ) end - defp to_vega(booster) do - tabular = - booster - |> EXGBoost.Booster.get_dump(format: :json) - |> Enum.map(&Jason.decode!(&1)) - |> Enum.with_index() - |> Enum.reduce( - [], - fn {tree, index}, acc -> - acc ++ to_tabular(index, tree, nil) - end - ) + @doc """ + Generates the necessary data transformation to convert the model into a tree structure that can be visualized by Vega. - spec = %{ + This function is useful if you want to create a custom Vega specification, but still want to use the data transformation + provided by the default specification. + """ + def get_data_spec(booster) do + %{ "$schema" => "https://vega.github.io/schema/vega/v5.json", - "description" => "A basic decision tree layout of nodes", - "width" => 800, - "height" => 600, - "padding" => 5, - "signals" => [ - %{ - "name" => "treeSelector", - "value" => 0, - "bind" => %{ - "input" => "select", - "options" => - tabular - |> Enum.reduce([], fn node, acc -> - id = Map.get(node, "tree_id") - if id in acc, do: acc, else: [id | acc] - end) - |> Enum.sort() - } - } - ], "data" => [ %{ "name" => "tree", - "values" => tabular, + "values" => to_tabular(booster), "transform" => [ - %{ - "type" => "filter", - "expr" => "datum.tree_id === treeSelector" - }, %{ "type" => "stratify", "key" => "nodeid", @@ -162,6 +248,10 @@ defmodule EXGBoost.Plotting do "type" => "tree", "size" => [%{"signal" => "width"}, %{"signal" => "height"}], "as" => ["x", "y", "depth", "children"] + }, + %{ + "type" => "filter", + "expr" => "datum.tree_id === selectedTree" } ] }, @@ -190,129 +280,206 @@ defmodule EXGBoost.Plotting do %{"type" => "filter", "expr" => "datum.leaf != null"} ] } - ], - "scales" => [ - %{ - "name" => "color", - "type" => "ordinal", - "domain" => %{"data" => "tree", "field" => "depth"}, - "range" => %{"scheme" => "category10"} - } - ], - "marks" => [ - %{ - "type" => "path", - "from" => %{"data" => "links"}, - "encode" => %{ - "update" => %{ - "path" => %{"field" => "path"}, - "stroke" => %{"value" => "#ccc"} - } - } - }, - %{ - "type" => "text", - "from" => %{"data" => "tree"}, - "encode" => %{ - "enter" => %{ - "text" => %{ - "signal" => - "datum.split ? datum.split + ' <= ' + format(datum.split_condition, '.2f') : ''" - }, - "fontSize" => %{"value" => 10}, - "baseline" => %{"value" => "middle"}, - "align" => %{"value" => "center"}, - "dx" => %{"value" => 0}, - "dy" => %{"value" => 0} - }, - "update" => %{ - "x" => %{"field" => "x"}, - "y" => %{"field" => "y"}, - "fill" => %{"value" => "black"} - } + ] + } + end + + def to_vega(booster, opts \\ []) do + opts = NimbleOptions.validate!(opts, @plotting_schema) |> opts_to_vl_props() + spec = get_data_spec(booster) + spec = Map.merge(spec, opts) + + spec = %{ + spec + | "scales" => [ + %{ + "name" => "color", + "type" => "ordinal", + "domain" => %{"data" => "tree", "field" => "depth"}, + "range" => %{"scheme" => "category10"} } - }, - %{ - "type" => "symbol", - "from" => %{"data" => "innerNodes"}, - "encode" => %{ - "enter" => %{ - "fill" => %{"value" => "none"}, - "shape" => %{"value" => "circle"}, - "x" => %{"field" => "x"}, - "y" => %{"field" => "y"}, - "size" => %{"value" => 800} + ], + "signals" => [ + Map.merge( + %{ + "name" => "selectedTree", + "value" => opts[:index] || 0 }, - "update" => %{ - "fill" => %{"scale" => "color", "field" => "depth"}, - "opacity" => %{"value" => 1}, - "stroke" => %{"value" => "black"}, - "strokeWidth" => %{"value" => 1}, - "fillOpacity" => %{"value" => 0}, - "strokeOpacity" => %{"value" => 1} + if(opts[:index], + do: %{}, + else: %{ + "bind" => %{ + "input" => "select", + "options" => + to_tabular(booster) + |> Enum.reduce([], fn node, acc -> + id = Map.get(node, "tree_id") + if id in acc, do: acc, else: [id | acc] + end) + |> Enum.sort() + } + } + ) + ) + ], + "marks" => [ + %{ + "type" => "path", + "from" => %{"data" => "links"}, + "encode" => %{ + "update" => %{ + "path" => %{"field" => "path"}, + "stroke" => %{"value" => "#ccc"} + } } - } - }, - %{ - "type" => "rect", - "from" => %{"data" => "leafNodes"}, - "encode" => %{ - "enter" => %{ - "x" => %{"signal" => "datum.x - 20"}, - "y" => %{"signal" => "datum.y - 10"}, - "width" => %{"value" => 40}, - "height" => %{"value" => 20}, - "stroke" => %{"value" => "black"}, - "fill" => %{"value" => "transparent"} + }, + %{ + "type" => "text", + "from" => %{"data" => "tree"}, + "encode" => %{ + "enter" => %{ + "text" => %{ + "signal" => + "datum.split ? datum.split + ' <= ' + format(datum.split_condition, '.2f') : ''" + }, + "fontSize" => %{"value" => 10}, + "baseline" => %{"value" => "middle"}, + "align" => %{"value" => "center"}, + "dx" => %{"value" => 0}, + "dy" => %{"value" => 0} + }, + "update" => %{ + "x" => %{"field" => "x"}, + "y" => %{"field" => "y"}, + "fill" => %{"value" => "black"} + } } - } - }, - %{ - "type" => "text", - "from" => %{"data" => "tree"}, - "encode" => %{ - "enter" => %{ - "text" => %{ - "signal" => "format(datum.leaf, '.2f')" + }, + %{ + "type" => "symbol", + "from" => %{"data" => "innerNodes"}, + "encode" => %{ + "enter" => %{ + "fill" => %{"value" => "none"}, + "shape" => %{"value" => "circle"}, + "x" => %{"field" => "x"}, + "y" => %{"field" => "y"}, + "size" => %{"value" => 800} }, - "fontSize" => %{"value" => 10}, - "baseline" => %{"value" => "middle"}, - "align" => %{"value" => "center"}, - "dx" => %{"value" => 0}, - "dy" => %{"value" => 0} - }, - "update" => %{ - "x" => %{"field" => "x"}, - "y" => %{"field" => "y"}, - "fill" => %{"value" => "black"}, - "opacity" => %{"signal" => "datum.leaf != null ? 1 : 0"} + "update" => %{ + "fill" => %{"scale" => "color", "field" => "depth"}, + "opacity" => %{"value" => 1}, + "stroke" => %{"value" => "black"}, + "strokeWidth" => %{"value" => 1}, + "fillOpacity" => %{"value" => 0}, + "strokeOpacity" => %{"value" => 1} + } } - } - }, - %{ - "type" => "text", - "from" => %{"data" => "links"}, - "encode" => %{ - "enter" => %{ - "x" => %{"signal" => "(datum.source.x + datum.target.x) / 2"}, - "y" => %{"signal" => "(datum.source.y + datum.target.y) / 2"}, - "text" => %{ - "signal" => "datum.source.yes === datum.target.nodeid ? 'yes' : 'no'" + }, + %{ + "type" => "rect", + "from" => %{"data" => "leafNodes"}, + "encode" => %{ + "enter" => %{ + "x" => %{"signal" => "datum.x - 20"}, + "y" => %{"signal" => "datum.y - 10"}, + "width" => %{"value" => 40}, + "height" => %{"value" => 20}, + "stroke" => %{"value" => "black"}, + "fill" => %{"value" => "transparent"} + } + } + }, + %{ + "type" => "text", + "from" => %{"data" => "tree"}, + "encode" => %{ + "enter" => %{ + "text" => %{ + "signal" => "format(datum.leaf, '.2f')" + }, + "fontSize" => %{"value" => 10}, + "baseline" => %{"value" => "middle"}, + "align" => %{"value" => "center"}, + "dx" => %{"value" => 0}, + "dy" => %{"value" => 0} }, - "align" => %{"value" => "center"}, - "baseline" => %{"value" => "middle"} + "update" => %{ + "x" => %{"field" => "x"}, + "y" => %{"field" => "y"}, + "fill" => %{"value" => "black"}, + "opacity" => %{"signal" => "datum.leaf != null ? 1 : 0"} + } + } + }, + %{ + "type" => "text", + "from" => %{"data" => "links"}, + "encode" => %{ + "enter" => %{ + "x" => %{"signal" => "(datum.source.x + datum.target.x) / 2"}, + "y" => %{"signal" => "(datum.source.y + datum.target.y) / 2"}, + "text" => %{ + "signal" => "datum.source.yes === datum.target.nodeid ? 'yes' : 'no'" + }, + "align" => %{"value" => "center"}, + "baseline" => %{"value" => "middle"} + } } } - } - ] + ] } + spec = (opts[:validate] && validate_spec(spec)) || spec + Jason.encode!(spec) |> VegaLite.from_json() end + + # Helpers (from https://github.com/livebook-dev/vega_lite/blob/v0.1.8/lib/vega_lite.ex#L1094) + + defp opts_to_vl_props(opts) do + opts |> Map.new() |> to_vl() + end + + defp to_vl(value) when value in [true, false, nil], do: value + + defp to_vl(atom) when is_atom(atom), do: to_vl_key(atom) + + defp to_vl(%_{} = struct), do: struct + + defp to_vl(map) when is_map(map) do + Map.new(map, fn {key, value} -> + {to_vl(key), to_vl(value)} + end) + end + + defp to_vl([{key, _} | _] = keyword) when is_atom(key) do + Map.new(keyword, fn {key, value} -> + {to_vl(key), to_vl(value)} + end) + end + + defp to_vl(list) when is_list(list) do + Enum.map(list, &to_vl/1) + end + + defp to_vl(value), do: value + + defp to_vl_key(key) when is_atom(key) do + key |> to_string() |> snake_to_camel() + end + + defp snake_to_camel(string) do + [part | parts] = String.split(string, "_") + Enum.join([String.downcase(part, :ascii) | Enum.map(parts, &capitalize/1)]) + end + + defp capitalize(<>) when first in ?a..?z, do: <> + defp capitalize(rest), do: rest end defimpl Kino.Render, for: EXGBoost.Booster do def to_livebook(booster) do - Kino.Render.to_livebook(EXGBoost.Plotting.to_vega(booster)) + Kino.Render.to_livebook(EXGBoost.Plotting.to_vega(booster) |> Kino.Render.to_livebook()) end end diff --git a/mix.exs b/mix.exs index d31685f..c2a68b4 100644 --- a/mix.exs +++ b/mix.exs @@ -49,7 +49,7 @@ defmodule EXGBoost.MixProject do {:ex_doc, "~> 0.29.0", only: :docs}, {:cc_precompiler, "~> 0.1.0", runtime: false}, {:exterval, "0.1.0"}, - {:exonerate, "~> 1.1", runtime: false}, + {:ex_json_schema, "~> 0.10.2"}, {:httpoison, "~> 2.0", runtime: false}, {:vega_lite, "~> 0.1"}, {:kino, "~> 0.11"} diff --git a/mix.lock b/mix.lock index c313017..ad391b0 100644 --- a/mix.lock +++ b/mix.lock @@ -3,9 +3,11 @@ "cc_precompiler": {:hex, :cc_precompiler, "0.1.7", "77de20ac77f0e53f20ca82c563520af0237c301a1ec3ab3bc598e8a96c7ee5d9", [:mix], [{:elixir_make, "~> 0.7.3", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "2768b28bf3c2b4f788c995576b39b8cb5d47eb788526d93bd52206c1d8bf4b75"}, "certifi": {:hex, :certifi, "2.12.0", "2d1cca2ec95f59643862af91f001478c9863c2ac9cb6e2f89780bfd8de987329", [:rebar3], [], "hexpm", "ee68d85df22e554040cdb4be100f33873ac6051387baf6a8f6ce82272340ff1c"}, "complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"}, + "decimal": {:hex, :decimal, "2.1.1", "5611dca5d4b2c3dd497dec8f68751f1f1a54755e8ed2a966c2633cf885973ad6", [:mix], [], "hexpm", "53cfe5f497ed0e7771ae1a475575603d77425099ba5faef9394932b35020ffcc"}, "earmark_parser": {:hex, :earmark_parser, "1.4.32", "fa739a0ecfa34493de19426681b23f6814573faee95dfd4b4aafe15a7b5b32c6", [:mix], [], "hexpm", "b8b0dd77d60373e77a3d7e8afa598f325e49e8663a51bcc2b88ef41838cca755"}, "elixir_make": {:hex, :elixir_make, "0.7.7", "7128c60c2476019ed978210c245badf08b03dbec4f24d05790ef791da11aa17c", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}], "hexpm", "5bc19fff950fad52bbe5f211b12db9ec82c6b34a9647da0c2224b8b8464c7e6c"}, "ex_doc": {:hex, :ex_doc, "0.29.4", "6257ecbb20c7396b1fe5accd55b7b0d23f44b6aa18017b415cb4c2b91d997729", [:mix], [{:earmark_parser, "~> 1.4.31", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_elixir, "~> 0.14", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1", [hex: :makeup_erlang, repo: "hexpm", optional: false]}], "hexpm", "2c6699a737ae46cb61e4ed012af931b57b699643b24dabe2400a8168414bc4f5"}, + "ex_json_schema": {:hex, :ex_json_schema, "0.10.2", "7c4b8c1481fdeb1741e2ce66223976edfb9bccebc8014f6aec35d4efe964fb71", [:mix], [{:decimal, "~> 2.0", [hex: :decimal, repo: "hexpm", optional: false]}], "hexpm", "37f43be60f8407659d4d0155a7e45e7f406dab1f827051d3d35858a709baf6a6"}, "exonerate": {:hex, :exonerate, "1.1.3", "96c3249f0358567151d595a49d612849a7a57930575a1eb781c178f2e84b1559", [:mix], [{:finch, "~> 0.15", [hex: :finch, repo: "hexpm", optional: true]}, {:idna, "~> 6.1.1", [hex: :idna, repo: "hexpm", optional: true]}, {:jason, "~> 1.4.0", [hex: :jason, repo: "hexpm", optional: false]}, {:json_ptr, "~> 1.0", [hex: :json_ptr, repo: "hexpm", optional: false]}, {:match_spec, "~> 0.3.1", [hex: :match_spec, repo: "hexpm", optional: false]}, {:pegasus, "~> 0.2.2", [hex: :pegasus, repo: "hexpm", optional: true]}, {:req, "~> 0.3", [hex: :req, repo: "hexpm", optional: true]}, {:yaml_elixir, "~> 2.7", [hex: :yaml_elixir, repo: "hexpm", optional: true]}], "hexpm", "cb750e303ae79987153eecb474ccd011f45e3dea7c4f988d8286facdf56db1d2"}, "exterval": {:hex, :exterval, "0.1.0", "d35ec43b0f260239f859665137fac0974f1c6a8d50bf8d52b5999c87c67c63e5", [:mix], [], "hexpm", "8ed444558a501deec6563230e3124cdf242c413a95eb9ca9f39de024ad779d7f"}, "finch": {:hex, :finch, "0.16.0", "40733f02c89f94a112518071c0a91fe86069560f5dbdb39f9150042f44dcfb1a", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:mime, "~> 1.0 or ~> 2.0", [hex: :mime, repo: "hexpm", optional: false]}, {:mint, "~> 1.3", [hex: :mint, repo: "hexpm", optional: false]}, {:nimble_options, "~> 0.4 or ~> 1.0", [hex: :nimble_options, repo: "hexpm", optional: false]}, {:nimble_pool, "~> 0.2.6 or ~> 1.0", [hex: :nimble_pool, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "f660174c4d519e5fec629016054d60edd822cdfe2b7270836739ac2f97735ec5"}, From a1a52bf14e7a54765ba64a17e3791260974b07ef Mon Sep 17 00:00:00 2001 From: acalejos Date: Tue, 12 Dec 2023 22:04:51 -0500 Subject: [PATCH 03/21] add plotting --- lib/exgboost.ex | 16 +- lib/exgboost/plotting.ex | 669 ++++++++++++++++++++++++------ lib/exgboost/training.ex | 28 +- lib/exgboost/training/callback.ex | 1 - test/exgboost_test.exs | 14 + 5 files changed, 581 insertions(+), 147 deletions(-) diff --git a/lib/exgboost.ex b/lib/exgboost.ex index d893d2b..9b3bb07 100644 --- a/lib/exgboost.ex +++ b/lib/exgboost.ex @@ -558,13 +558,15 @@ defmodule EXGBoost do * `:format` - the format to export the graphic as, must be either of: `:json`, `:html`, `:png`, `:svg`, `:pdf`. By default the format is inferred from the file extension. * `:local_npm_prefix` - a relative path pointing to a local npm project directory where the necessary npm packages are installed. For instance, in Phoenix projects you may want to pass local_npm_prefix: "assets". By default the npm packages are searched for in the current directory and globally. """ - def plot_tree(booster, path, opts \\ []) when is_binary(path) do - vega = Plotting.to_vega(booster) - - if Keyword.get(opts, :path, nil) == nil do - :ok = VegaLite.Export.save!(vega, path, opts) + def plot_tree(booster, opts \\ []) do + {path, opts} = Keyword.pop(opts, :path) + {save_opts, opts} = Keyword.split(opts, [:format, :local_npm_prefix]) + vega = Plotting.to_vega(booster, opts) + + if path != nil do + VegaLite.Export.save!(vega, path, save_opts) + else + vega end - - vega end end diff --git a/lib/exgboost/plotting.ex b/lib/exgboost/plotting.ex index dc1a972..9e64a47 100644 --- a/lib/exgboost/plotting.ex +++ b/lib/exgboost/plotting.ex @@ -82,6 +82,18 @@ defmodule EXGBoost.Plotting do |> ExJsonSchema.Schema.resolve() @plotting_params [ + autosize: [ + doc: + "Determines if the visualization should automatically resize when the window size changes", + type: {:in, ["fit", "pad", "fit-x", "fit-y", "none"]}, + default: "fit" + ], + background: [ + doc: + "The background color of the visualization. Accepts a valid CSS color string. For example: `#f304d3`, `#ccc`, `rgb(253, 12, 134)`, `steelblue.`", + type: :string, + default: "#f5f5f5" + ], height: [ doc: "Height of the plot in pixels", type: :pos_integer, @@ -108,20 +120,71 @@ defmodule EXGBoost.Plotting do ]}, default: 5 ], - leafs: [ + leaves: [ doc: "Specifies characteristics of leaf nodes", type: :keyword_list, - default: [fill: "#ccc", stroke: "black", strokeWidth: 1] + keys: [ + fill: [type: :string, default: "#ccc"], + stroke: [type: :string, default: "black"], + stroke_width: [type: :pos_integer, default: 1], + font_size: [type: :pos_integer, default: 13] + ], + default: [ + fill: "#ccc", + stroke: "black", + stroke_width: 1, + font_size: 13 + ] + ], + node_width: [ + doc: "The width of each node in pixels", + type: :pos_integer, + default: 100 + ], + node_height: [ + doc: "The height of each node in pixels", + type: :pos_integer, + default: 45 + ], + space_between: [ + type: :keyword_list, + keys: [ + nodes: [type: :pos_integer, default: 10], + levels: [type: :pos_integer, default: 100] + ], + default: [nodes: 10, levels: 100] ], - inner_nodes: [ + splits: [ doc: "Specifies characteristics of inner nodes", - type: :boolean, - default: true + type: :keyword_list, + keys: [ + fill: [type: :string, default: "#ccc"], + stroke: [type: :string, default: "black"], + stroke_width: [type: :pos_integer, default: 1], + font_size: [type: :pos_integer, default: 13] + ], + default: [ + fill: "#ccc", + stroke: "black", + stroke_width: 1, + font_size: 13 + ] ], links: [ doc: "Specifies characteristics of links between nodes", type: :keyword_list, - default: [stroke: "#ccc", strokeWidth: 1] + keys: [ + fill: [type: :string, default: "#ccc"], + stroke: [type: :string, default: "black"], + stroke_width: [type: :pos_integer, default: 1], + font_size: [type: :pos_integer, default: 13] + ], + default: [ + fill: "#ccc", + stroke: "black", + stroke_width: 1, + font_size: 13 + ] ], validate: [ doc: "Whether to validate the Vega specification against the Vega schema", @@ -131,8 +194,14 @@ defmodule EXGBoost.Plotting do index: [ doc: "The zero-indexed index of the tree to plot. If `nil`, plots all trees using a dropdown selector to switch between trees", - type: {:or, [nil, :pos_integer]}, + type: {:or, [nil, :non_neg_integer]}, default: 0 + ], + depth: [ + doc: + "The depth of the tree to plot. If `nil`, plots all levels (cick on a node to expand/collapse)", + type: {:or, [nil, :non_neg_integer]}, + default: nil ] ] @@ -141,13 +210,6 @@ defmodule EXGBoost.Plotting do def get_schema(), do: @schema defp validate_spec(spec) do - unless Code.ensure_loaded?(EXJsonSchema) do - raise( - RuntimeError, - "The `ex_json_schema` package must be installed to validate Vega specifications. Please install it by running `mix deps.get`" - ) - end - case ExJsonSchema.Validator.validate(get_schema(), spec) do :ok -> spec @@ -163,7 +225,12 @@ defmodule EXGBoost.Plotting do defp _to_tabular(idx, node, parent) do node = Map.put(node, "tree_id", idx) node = if parent, do: Map.put(node, "parentid", parent), else: node - node = new_node(node) + + node = + new_node(node) + |> Map.update("nodeid", nil, &(&1 + 1)) + |> Map.update("yes", nil, &if(&1, do: &1 + 1)) + |> Map.update("no", nil, &if(&1, do: &1 + 1)) if Map.has_key?(node, "children") do [ @@ -235,49 +302,207 @@ defmodule EXGBoost.Plotting do %{ "$schema" => "https://vega.github.io/schema/vega/v5.json", "data" => [ + %{"name" => "tree", "values" => to_tabular(booster)}, %{ - "name" => "tree", - "values" => to_tabular(booster), + "name" => "treeCalcs", + "source" => "tree", "transform" => [ + %{"expr" => "datum.tree_id === selectedTree", "type" => "filter"}, + %{"key" => "nodeid", "parentKey" => "parentid", "type" => "stratify"}, %{ - "type" => "stratify", - "key" => "nodeid", - "parentKey" => "parentid" + "as" => ["y", "x", "depth", "children"], + "method" => "tidy", + "separation" => %{"signal" => "false"}, + "type" => "tree" + } + ] + }, + %{ + "name" => "treeChildren", + "source" => "treeCalcs", + "transform" => [ + %{ + "as" => ["childrenObjects"], + "fields" => ["parentid"], + "groupby" => ["parentid"], + "ops" => ["values"], + "type" => "aggregate" }, %{ - "type" => "tree", - "size" => [%{"signal" => "width"}, %{"signal" => "height"}], - "as" => ["x", "y", "depth", "children"] + "as" => "childrenIds", + "expr" => "pluck(datum.childrenObjects,'nodeid')", + "type" => "formula" + } + ] + }, + %{ + "name" => "treeAncestors", + "source" => "treeCalcs", + "transform" => [ + %{ + "as" => "treeAncestors", + "expr" => "treeAncestors('treeCalcs', datum.nodeid, 'root')", + "type" => "formula" }, + %{"fields" => ["treeAncestors"], "type" => "flatten"}, %{ - "type" => "filter", - "expr" => "datum.tree_id === selectedTree" + "as" => "allParents", + "expr" => "datum.treeAncestors.parentid", + "type" => "formula" } ] }, %{ - "name" => "links", - "source" => "tree", + "name" => "treeChildrenAll", + "source" => "treeAncestors", "transform" => [ - %{"type" => "treelinks"}, %{ - "type" => "linkpath", - "shape" => "line" + "fields" => [ + "allParents", + "nodeid", + "name", + "parentid", + "x", + "y", + "depth", + "children" + ], + "type" => "project" + }, + %{ + "as" => ["allChildrenObjects", "allChildrenCount", "id"], + "fields" => ["parentid", "parentid", "nodeid"], + "groupby" => ["allParents"], + "ops" => ["values", "count", "min"], + "type" => "aggregate" + }, + %{ + "as" => "allChildrenIds", + "expr" => "pluck(datum.allChildrenObjects,'nodeid')", + "type" => "formula" } ] }, %{ - "name" => "innerNodes", - "source" => "tree", + "name" => "treeClickStoreTemp", + "source" => "treeAncestors", "transform" => [ - %{"type" => "filter", "expr" => "datum.leaf == null"} + %{ + "expr" => + "startingDepth != -1 ? datum.depth <= startingDepth : node != 0 && !isExpanded ? datum.parentid == node: node != 0 && isExpanded ? datum.allParents == node : false", + "type" => "filter" + }, + %{ + "fields" => ["nodeid", "parentid", "x", "y", "depth", "children"], + "type" => "project" + }, + %{ + "fields" => ["nodeid"], + "groupby" => ["nodeid", "parentid", "x", "y", "depth", "children"], + "ops" => ["min"], + "type" => "aggregate" + } ] }, %{ - "name" => "leafNodes", + "name" => "treeClickStorePerm", + "on" => [ + %{"insert" => "data('treeClickStoreTemp')", "trigger" => "startingDepth >= 0"}, + %{ + "insert" => "!isExpanded ? data('treeClickStoreTemp'): false", + "trigger" => "node" + }, + %{"remove" => "isExpanded ? data('treeClickStoreTemp'): false", "trigger" => "node"} + ], + "values" => [] + }, + %{ + "name" => "treeLayout", "source" => "tree", "transform" => [ - %{"type" => "filter", "expr" => "datum.leaf != null"} + %{"expr" => "datum.tree_id === selectedTree", "type" => "filter"}, + %{ + "expr" => "indata('treeClickStorePerm', 'nodeid', datum.nodeid)", + "type" => "filter" + }, + %{"key" => "nodeid", "parentKey" => "parentid", "type" => "stratify"}, + %{ + "as" => ["x", "y", "depth", "children"], + "method" => "tidy", + "nodeSize" => [ + %{"signal" => "nodeWidth + spaceBetweenNodes"}, + %{"signal" => "nodeHeight+ spaceBetweenLevels"} + ], + "separation" => %{"signal" => "false"}, + "type" => "tree" + }, + %{"as" => "y", "expr" => "datum.y+(height/10)", "type" => "formula"}, + %{"as" => "x", "expr" => "datum.x+(width/2)", "type" => "formula"}, + %{"as" => "xscaled", "expr" => "scale('xscale',datum.x)", "type" => "formula"}, + %{"as" => "parent", "expr" => "datum.parentid", "type" => "formula"} + ] + }, + %{ + "name" => "fullTreeLayout", + "source" => "treeLayout", + "transform" => [ + %{ + "fields" => ["nodeid"], + "from" => "treeChildren", + "key" => "parentid", + "type" => "lookup", + "values" => ["childrenObjects", "childrenIds"] + }, + %{ + "fields" => ["nodeid"], + "from" => "treeChildrenAll", + "key" => "allParents", + "type" => "lookup", + "values" => ["allChildrenIds", "allChildrenObjects"] + }, + %{ + "fields" => ["nodeid"], + "from" => "treeCalcs", + "key" => "nodeid", + "type" => "lookup", + "values" => ["children"] + }, + %{ + "as" => "treeParent", + "expr" => "reverse(pluck(treeAncestors('treeCalcs', datum.nodeid), 'nodeid'))[1]", + "type" => "formula" + }, + %{"as" => "isLeaf", "expr" => "datum.leaf == null", "type" => "formula"} + ] + }, + %{ + "name" => "visibleNodes", + "source" => "fullTreeLayout", + "transform" => [ + %{ + "expr" => "indata('treeClickStorePerm', 'nodeid', datum.nodeid)", + "type" => "filter" + } + ] + }, + %{ + "name" => "links", + "source" => "treeLayout", + "transform" => [ + %{"type" => "treelinks"}, + %{ + "orient" => "vertical", + "shape" => "line", + "sourceX" => %{"expr" => "scale('xscale', datum.source.x)"}, + "sourceY" => %{"expr" => "scale('yscale', datum.source.y)"}, + "targetX" => %{"expr" => "scale('xscale', datum.target.x)"}, + "targetY" => %{"expr" => "scale('yscale', datum.target.y) - scaledNodeHeight"}, + "type" => "linkpath" + }, + %{ + "expr" => " indata('treeClickStorePerm', 'nodeid', datum.target.nodeid)", + "type" => "filter" + } ] } ] @@ -285,18 +510,148 @@ defmodule EXGBoost.Plotting do end def to_vega(booster, opts \\ []) do - opts = NimbleOptions.validate!(opts, @plotting_schema) |> opts_to_vl_props() + opts = NimbleOptions.validate!(opts, @plotting_schema) spec = get_data_spec(booster) - spec = Map.merge(spec, opts) - spec = %{ - spec - | "scales" => [ + [%{"$ref" => root} | [%{"properties" => properties}]] = + EXGBoost.Plotting.get_schema().schema |> Map.get("allOf") + + top_level_keys = + (ExJsonSchema.Schema.get_fragment!(EXGBoost.Plotting.get_schema(), root) + |> Map.get("properties") + |> Map.keys()) ++ Map.keys(properties) + + tlk = opts |> opts_to_vl_props() |> Map.take(top_level_keys) + + spec = + Map.merge(spec, %{ + "$schema" => "https://vega.github.io/schema/vega/v5.json", + "marks" => [ %{ - "name" => "color", - "type" => "ordinal", - "domain" => %{"data" => "tree", "field" => "depth"}, - "range" => %{"scheme" => "category10"} + "encode" => %{ + "update" => %{ + "path" => %{"field" => "path"}, + "stroke" => %{ + "signal" => "datum.source.yes === datum.target.nodeid ? 'red' : 'black'" + }, + "strokeWidth" => %{ + "signal" => "indexof(nodeHighlight, datum.target.nodeid)> -1? 2:1" + } + } + }, + "from" => %{"data" => "links"}, + "interactive" => false, + "type" => "path" + }, + %{ + "encode" => %{ + "update" => %{ + "align" => %{"value" => "center"}, + "baseline" => %{"value" => "middle"}, + "fontSize" => %{"signal" => "linksFontSize"}, + "text" => %{"signal" => "datum.source.yes === datum.target.nodeid ? 'yes' : 'no'"}, + "x" => %{ + "signal" => + "(scale('xscale', datum.source.x+(nodeWidth/3)) + scale('xscale', datum.target.x)) / 2" + }, + "y" => %{ + "signal" => + "(scale('yscale', datum.source.y) + scale('yscale', datum.target.y)) / 2 - (scaledNodeHeight/2)" + } + } + }, + "from" => %{"data" => "links"}, + "type" => "text" + }, + %{ + "clip" => false, + "description" => "The parent node", + "encode" => %{ + "update" => %{ + "cornerRadius" => %{"value" => 2}, + "cursor" => %{"signal" => "datum.children > 0 ? 'pointer' : '' "}, + "fill" => %{"signal" => "datum.isLeaf ? leafColor : splitColor"}, + "height" => %{"signal" => "scaledNodeHeight"}, + "opacity" => %{"value" => 0}, + "tooltip" => %{"signal" => ""}, + "width" => %{"signal" => "scaledNodeWidth"}, + "x" => %{"signal" => "datum.xscaled - (scaledNodeWidth / 2)"}, + "yc" => %{"signal" => "scale('yscale',datum.y) - (scaledNodeHeight/2)"} + } + }, + "from" => %{"data" => "visibleNodes"}, + "marks" => [ + %{ + "description" => + "highlight (seems like a Vega bug as this doens't work on the group element)", + "encode" => %{ + "update" => %{ + "fill" => %{"signal" => "parent.isLeaf ? leafColor : splitColor"}, + "height" => %{"signal" => "item.mark.group.height"}, + "width" => %{"signal" => "item.mark.group.width"}, + "x" => %{"signal" => "item.mark.group.x1"}, + "y" => %{"signal" => "0"} + } + }, + "interactive" => false, + "name" => "highlight", + "type" => "rect" + }, + %{ + "encode" => %{ + "update" => %{ + "align" => %{"value" => "center"}, + "baseline" => %{"value" => "middle"}, + "fill" => %{"signal" => "'black'"}, + "font" => %{"value" => "Calibri"}, + "fontSize" => %{"signal" => "parent.split ? splitsFontSize : leavesFontSize"}, + "limit" => %{"signal" => "scaledNodeWidth-scaledLimit"}, + "text" => %{ + "signal" => + "parent.split ? parent.split + ' <= ' + format(parent.split_condition, '.2f') : 'leaf = ' + format(parent.leaf, '.2f')" + }, + "x" => %{"signal" => "(scaledNodeWidth / 2)"}, + "y" => %{"signal" => "scaledNodeHeight / 2"} + } + }, + "interactive" => false, + "name" => "title", + "type" => "text" + }, + %{ + "encode" => %{ + "update" => %{ + "align" => %{"value" => "right"}, + "baseline" => %{"value" => "middle"}, + "fill" => %{"signal" => "childrenColor"}, + "font" => %{"value" => "Calibri"}, + "fontSize" => %{"signal" => "splitsFontSize"}, + "text" => %{"signal" => "parent.children > 0 ? parent.children :''"}, + "x" => %{"signal" => "item.mark.group.width - (9/ span(xdom))*width"}, + "y" => %{"signal" => "item.mark.group.height/2"} + } + }, + "interactive" => false, + "name" => "node children", + "type" => "text" + } + ], + "name" => "node", + "type" => "group" + } + ], + "scales" => [ + %{ + "domain" => %{"signal" => "xdom"}, + "name" => "xscale", + "range" => %{"signal" => "xrange"}, + "zero" => false + }, + %{ + "domain" => %{"signal" => "ydom"}, + "name" => "yscale", + "range" => %{"signal" => "yrange"}, + "zero" => false } ], "signals" => [ @@ -320,116 +675,164 @@ defmodule EXGBoost.Plotting do } } ) - ) - ], - "marks" => [ + ), + %{"name" => "childrenColor", "value" => "black"}, + %{"name" => "leafColor", "value" => opts[:leaves][:fill]}, + %{"name" => "splitColor", "value" => opts[:splits][:fill]}, + %{"name" => "spaceBetweenLevels", "value" => opts[:space_between][:levels]}, + %{"name" => "spaceBetweenNodes", "value" => opts[:space_between][:nodes]}, + %{"name" => "nodeWidth", "value" => opts[:node_width]}, + %{"name" => "nodeHeight", "value" => opts[:node_height]}, %{ - "type" => "path", - "from" => %{"data" => "links"}, - "encode" => %{ - "update" => %{ - "path" => %{"field" => "path"}, - "stroke" => %{"value" => "#ccc"} + "name" => "startingDepth", + "on" => [%{"events" => %{"throttle" => 0, "type" => "timer"}, "update" => "-1"}], + "value" => + opts[:depth] || + to_tabular(booster) + |> Enum.map(&Map.get(&1, "depth")) + |> Enum.filter(& &1) + |> Enum.max() + }, + %{ + "name" => "node", + "on" => [ + %{ + "events" => %{"markname" => "node", "type" => "click"}, + "update" => "datum.nodeid" } - } + ], + "value" => 0 }, %{ - "type" => "text", - "from" => %{"data" => "tree"}, - "encode" => %{ - "enter" => %{ - "text" => %{ - "signal" => - "datum.split ? datum.split + ' <= ' + format(datum.split_condition, '.2f') : ''" - }, - "fontSize" => %{"value" => 10}, - "baseline" => %{"value" => "middle"}, - "align" => %{"value" => "center"}, - "dx" => %{"value" => 0}, - "dy" => %{"value" => 0} + "name" => "nodeHighlight", + "on" => [ + %{ + "events" => %{"markname" => "node", "type" => "mouseover"}, + "update" => "pluck(treeAncestors('treeCalcs', datum.nodeid), 'nodeid')" }, - "update" => %{ - "x" => %{"field" => "x"}, - "y" => %{"field" => "y"}, - "fill" => %{"value" => "black"} + %{"events" => %{"type" => "mouseout"}, "update" => "[0]"} + ], + "value" => "[0]" + }, + %{ + "name" => "isExpanded", + "on" => [ + %{ + "events" => %{"markname" => "node", "type" => "click"}, + "update" => + "datum.children > 0 && indata('treeClickStorePerm', 'nodeid', datum.childrenIds[0]) ? true : false" } - } + ], + "value" => 0 }, + %{"name" => "xrange", "update" => "[0, width]"}, + %{"name" => "yrange", "update" => "[0, height]"}, %{ - "type" => "symbol", - "from" => %{"data" => "innerNodes"}, - "encode" => %{ - "enter" => %{ - "fill" => %{"value" => "none"}, - "shape" => %{"value" => "circle"}, - "x" => %{"field" => "x"}, - "y" => %{"field" => "y"}, - "size" => %{"value" => 800} - }, - "update" => %{ - "fill" => %{"scale" => "color", "field" => "depth"}, - "opacity" => %{"value" => 1}, - "stroke" => %{"value" => "black"}, - "strokeWidth" => %{"value" => 1}, - "fillOpacity" => %{"value" => 0}, - "strokeOpacity" => %{"value" => 1} + "name" => "down", + "on" => [%{"events" => "mousedown", "update" => "xy()"}], + "value" => nil + }, + %{ + "name" => "xcur", + "on" => [%{"events" => "mousedown", "update" => "slice(xdom)"}], + "value" => nil + }, + %{ + "name" => "ycur", + "on" => [%{"events" => "mousedown", "update" => "slice(ydom)"}], + "value" => nil + }, + %{ + "name" => "delta", + "on" => [ + %{ + "events" => [ + %{ + "between" => [ + %{"type" => "mousedown"}, + %{"source" => "window", "type" => "mouseup"} + ], + "consume" => true, + "source" => "window", + "type" => "mousemove" + } + ], + "update" => "down ? [down[0]-x(), down[1]-y()] : [0,0]" } - } + ], + "value" => [0, 0] }, %{ - "type" => "rect", - "from" => %{"data" => "leafNodes"}, - "encode" => %{ - "enter" => %{ - "x" => %{"signal" => "datum.x - 20"}, - "y" => %{"signal" => "datum.y - 10"}, - "width" => %{"value" => 40}, - "height" => %{"value" => 20}, - "stroke" => %{"value" => "black"}, - "fill" => %{"value" => "transparent"} + "name" => "anchor", + "on" => [ + %{"events" => "wheel", "update" => "[invert('xscale', x()), invert('yscale', y())]"} + ], + "value" => [0, 0] + }, + %{"name" => "xext", "update" => "[0,width]"}, + %{"name" => "yext", "update" => "[0,height]"}, + %{ + "name" => "zoom", + "on" => [ + %{ + "events" => "wheel!", + "force" => true, + "update" => "pow(1.001, event.deltaY * pow(16, event.deltaMode))" } - } + ], + "value" => 1 }, %{ - "type" => "text", - "from" => %{"data" => "tree"}, - "encode" => %{ - "enter" => %{ - "text" => %{ - "signal" => "format(datum.leaf, '.2f')" - }, - "fontSize" => %{"value" => 10}, - "baseline" => %{"value" => "middle"}, - "align" => %{"value" => "center"}, - "dx" => %{"value" => 0}, - "dy" => %{"value" => 0} + "name" => "xdom", + "on" => [ + %{ + "events" => %{"signal" => "delta"}, + "update" => + "[xcur[0] + span(xcur) * delta[0] / width, xcur[1] + span(xcur) * delta[0] / width]" }, - "update" => %{ - "x" => %{"field" => "x"}, - "y" => %{"field" => "y"}, - "fill" => %{"value" => "black"}, - "opacity" => %{"signal" => "datum.leaf != null ? 1 : 0"} + %{ + "events" => %{"signal" => "zoom"}, + "update" => + "[anchor[0] + (xdom[0] - anchor[0]) * zoom, anchor[0] + (xdom[1] - anchor[0]) * zoom]" } - } + ], + "update" => "slice(xext)" }, %{ - "type" => "text", - "from" => %{"data" => "links"}, - "encode" => %{ - "enter" => %{ - "x" => %{"signal" => "(datum.source.x + datum.target.x) / 2"}, - "y" => %{"signal" => "(datum.source.y + datum.target.y) / 2"}, - "text" => %{ - "signal" => "datum.source.yes === datum.target.nodeid ? 'yes' : 'no'" - }, - "align" => %{"value" => "center"}, - "baseline" => %{"value" => "middle"} + "name" => "ydom", + "on" => [ + %{ + "events" => %{"signal" => "delta"}, + "update" => + "[ycur[0] + span(ycur) * delta[1] / height, ycur[1] + span(ycur) * delta[1] / height]" + }, + %{ + "events" => %{"signal" => "zoom"}, + "update" => + "[anchor[1] + (ydom[0] - anchor[1]) * zoom, anchor[1] + (ydom[1] - anchor[1]) * zoom]" } - } + ], + "update" => "slice(yext)" + }, + %{"name" => "scaledNodeWidth", "update" => "(nodeWidth/ span(xdom))*width"}, + %{"name" => "scaledNodeHeight", "update" => "abs(nodeHeight/ span(ydom))*height"}, + %{"name" => "scaledLimit", "update" => "(20/ span(xdom))*width"}, + %{ + "name" => "leavesFontSize", + "update" => "(#{opts[:leaves][:font_size]}/ span(xdom))*width" + }, + %{ + "name" => "splitsFontSize", + "update" => "(#{opts[:splits][:font_size]}/ span(xdom))*width" + }, + %{ + "name" => "linksFontSize", + "update" => "(#{opts[:links][:font_size]}/ span(xdom))*width" } ] - } + }) + spec = Map.merge(spec, tlk) spec = (opts[:validate] && validate_spec(spec)) || spec Jason.encode!(spec) |> VegaLite.from_json() @@ -480,6 +883,6 @@ end defimpl Kino.Render, for: EXGBoost.Booster do def to_livebook(booster) do - Kino.Render.to_livebook(EXGBoost.Plotting.to_vega(booster) |> Kino.Render.to_livebook()) + EXGBoost.Plotting.to_vega(booster) |> Kino.Render.to_livebook() end end diff --git a/lib/exgboost/training.ex b/lib/exgboost/training.ex index 624fa39..41f5a76 100644 --- a/lib/exgboost/training.ex +++ b/lib/exgboost/training.ex @@ -91,14 +91,30 @@ defmodule EXGBoost.Training do callbacks = unless is_nil(opts[:early_stopping_rounds]) do unless evals_dmats == [] do + # TODO This currently works due a 1-line addition made to XGBoost that + # is done within a sed command in the Makefile. + # This is necessary because the default metrics are not serialized + # with the booster config, so if we didn't do this we would have to + # run 1 iteration of evaluation before we could get the default + + # TODO: This is a hack to get around the aforementioned issue [{_dmat, target_eval} | _tail] = Enum.reverse(evals_dmats) - # Default to the last metric - [%{"name" => metric_name} | _tail] = - EXGBoost.dump_config(bst) - |> Jason.decode!() - |> get_in(["learner", "metrics"]) - |> Enum.reverse() + %{"learner" => %{"metrics" => metrics, "default_metric" => default_metric}} = + EXGBoost.dump_config(bst) |> Jason.decode!() + + metric_name = + cond do + Enum.empty?(metrics) && opts[:disable_default_eval_metric] -> + raise ArgumentError, + "`:early_stopping_rounds` requires at least one evaluation set. This means you have likely set `disable_default_eval_metric: true` and have not set any explicit evalutation metrics. Please supply at least one metric in the `:eval_metric` option or set `disable_default_eval_metric: false` (default option)" + + Enum.empty?(metrics) -> + default_metric + + true -> + metrics |> Enum.reverse() |> hd() |> Map.fetch!("name") + end [ %Callback{ diff --git a/lib/exgboost/training/callback.ex b/lib/exgboost/training/callback.ex index 15b69ee..95794bb 100644 --- a/lib/exgboost/training/callback.ex +++ b/lib/exgboost/training/callback.ex @@ -211,7 +211,6 @@ defmodule EXGBoost.Training.Callback do |> Map.filter(filter) {:cont, %{state | metrics: metrics}} - {:cont, %{state | metrics: metrics}} end @doc """ diff --git a/test/exgboost_test.exs b/test/exgboost_test.exs index 9b15aca..2546331 100644 --- a/test/exgboost_test.exs +++ b/test/exgboost_test.exs @@ -151,6 +151,20 @@ defmodule EXGBoostTest do refute is_nil(booster.best_iteration) refute is_nil(booster.best_score) + + {booster, _} = + ExUnit.CaptureIO.with_io(fn -> + EXGBoost.train(x, y, + disable_default_eval_metric: true, + num_boost_rounds: 10, + early_stopping_rounds: 1, + evals: [{x, y, "validation"}], + tree_method: :hist + ) + end) + + refute is_nil(booster.best_iteration) + refute is_nil(booster.best_score) end test "eval with multiple metrics", context do From a680d94957e043199c5af61082369c2bf6524206 Mon Sep 17 00:00:00 2001 From: acalejos Date: Wed, 13 Dec 2023 23:49:35 -0500 Subject: [PATCH 04/21] Separate spec into components --- lib/exgboost/plotting.ex | 200 ++++++++++++++++++++++++++++++--------- 1 file changed, 156 insertions(+), 44 deletions(-) diff --git a/lib/exgboost/plotting.ex b/lib/exgboost/plotting.ex index 9e64a47..df45ac6 100644 --- a/lib/exgboost/plotting.ex +++ b/lib/exgboost/plotting.ex @@ -81,6 +81,8 @@ defmodule EXGBoost.Plotting do |> Jason.decode!() |> ExJsonSchema.Schema.resolve() + @mark_params [] + @plotting_params [ autosize: [ doc: @@ -124,16 +126,12 @@ defmodule EXGBoost.Plotting do doc: "Specifies characteristics of leaf nodes", type: :keyword_list, keys: [ - fill: [type: :string, default: "#ccc"], - stroke: [type: :string, default: "black"], - stroke_width: [type: :pos_integer, default: 1], - font_size: [type: :pos_integer, default: 13] + text: [], + rect: [] ], default: [ - fill: "#ccc", - stroke: "black", - stroke_width: 1, - font_size: 13 + text: [], + rect: [] ] ], node_width: [ @@ -174,16 +172,8 @@ defmodule EXGBoost.Plotting do doc: "Specifies characteristics of links between nodes", type: :keyword_list, keys: [ - fill: [type: :string, default: "#ccc"], - stroke: [type: :string, default: "black"], - stroke_width: [type: :pos_integer, default: 1], - font_size: [type: :pos_integer, default: 13] - ], - default: [ - fill: "#ccc", - stroke: "black", - stroke_width: 1, - font_size: 13 + yes: [], + no: [] ] ], validate: [ @@ -485,6 +475,26 @@ defmodule EXGBoost.Plotting do } ] }, + %{ + "name" => "splitNodes", + "source" => "visibleNodes", + "transform" => [ + %{ + "type" => "filter", + "expr" => "datum.isLeaf" + } + ] + }, + %{ + "name" => "leafNodes", + "source" => "visibleNodes", + "transform" => [ + %{ + "type" => "filter", + "expr" => "!datum.isLeaf" + } + ] + }, %{ "name" => "links", "source" => "treeLayout", @@ -504,6 +514,26 @@ defmodule EXGBoost.Plotting do "type" => "filter" } ] + }, + %{ + "name" => "yesPaths", + "source" => "links", + "transform" => [ + %{ + "type" => "filter", + "expr" => "datum.source.yes === datum.target.nodeid " + } + ] + }, + %{ + "name" => "noPaths", + "source" => "links", + "transform" => [ + %{ + "type" => "filter", + "expr" => "datum.source.yes !== datum.target.nodeid " + } + ] } ] } @@ -539,7 +569,23 @@ defmodule EXGBoost.Plotting do } } }, - "from" => %{"data" => "links"}, + "from" => %{"data" => "yesPaths"}, + "interactive" => false, + "type" => "path" + }, + %{ + "encode" => %{ + "update" => %{ + "path" => %{"field" => "path"}, + "stroke" => %{ + "signal" => "datum.source.yes === datum.target.nodeid ? 'red' : 'black'" + }, + "strokeWidth" => %{ + "signal" => "indexof(nodeHighlight, datum.target.nodeid)> -1? 2:1" + } + } + }, + "from" => %{"data" => "noPaths"}, "interactive" => false, "type" => "path" }, @@ -560,43 +606,47 @@ defmodule EXGBoost.Plotting do } } }, - "from" => %{"data" => "links"}, + "from" => %{"data" => "yesPaths"}, + "type" => "text" + }, + %{ + "encode" => %{ + "update" => %{ + "align" => %{"value" => "center"}, + "baseline" => %{"value" => "middle"}, + "fontSize" => %{"signal" => "linksFontSize"}, + "text" => %{"signal" => "datum.source.yes === datum.target.nodeid ? 'yes' : 'no'"}, + "x" => %{ + "signal" => + "(scale('xscale', datum.source.x+(nodeWidth/3)) + scale('xscale', datum.target.x)) / 2" + }, + "y" => %{ + "signal" => + "(scale('yscale', datum.source.y) + scale('yscale', datum.target.y)) / 2 - (scaledNodeHeight/2)" + } + } + }, + "from" => %{"data" => "noPaths"}, "type" => "text" }, %{ "clip" => false, - "description" => "The parent node", "encode" => %{ "update" => %{ "cornerRadius" => %{"value" => 2}, "cursor" => %{"signal" => "datum.children > 0 ? 'pointer' : '' "}, "fill" => %{"signal" => "datum.isLeaf ? leafColor : splitColor"}, "height" => %{"signal" => "scaledNodeHeight"}, - "opacity" => %{"value" => 0}, + "opacity" => %{"value" => 1}, "tooltip" => %{"signal" => ""}, "width" => %{"signal" => "scaledNodeWidth"}, "x" => %{"signal" => "datum.xscaled - (scaledNodeWidth / 2)"}, "yc" => %{"signal" => "scale('yscale',datum.y) - (scaledNodeHeight/2)"} } }, - "from" => %{"data" => "visibleNodes"}, + "from" => %{"data" => "splitNodes"}, + "name" => "splitNode", "marks" => [ - %{ - "description" => - "highlight (seems like a Vega bug as this doens't work on the group element)", - "encode" => %{ - "update" => %{ - "fill" => %{"signal" => "parent.isLeaf ? leafColor : splitColor"}, - "height" => %{"signal" => "item.mark.group.height"}, - "width" => %{"signal" => "item.mark.group.width"}, - "x" => %{"signal" => "item.mark.group.x1"}, - "y" => %{"signal" => "0"} - } - }, - "interactive" => false, - "name" => "highlight", - "type" => "rect" - }, %{ "encode" => %{ "update" => %{ @@ -632,11 +682,56 @@ defmodule EXGBoost.Plotting do } }, "interactive" => false, - "name" => "node children", "type" => "text" } ], - "name" => "node", + "type" => "group" + }, + %{ + "clip" => false, + "encode" => %{ + "update" => %{ + "cornerRadius" => %{"value" => 2}, + "cursor" => %{"signal" => "datum.children > 0 ? 'pointer' : '' "}, + "fill" => %{"signal" => "datum.isLeaf ? leafColor : splitColor"}, + "height" => %{"signal" => "scaledNodeHeight"}, + "opacity" => %{"value" => 1}, + "tooltip" => %{"signal" => ""}, + "width" => %{"signal" => "scaledNodeWidth"}, + "x" => %{"signal" => "datum.xscaled - (scaledNodeWidth / 2)"}, + "yc" => %{"signal" => "scale('yscale',datum.y) - (scaledNodeHeight/2)"} + } + }, + "from" => %{"data" => "leafNodes"}, + "name" => "leafNode", + "marks" => [ + %{ + "encode" => %{ + "update" => + Map.merge( + %{ + # "align" => %{"value" => "center"}, + # "baseline" => %{"value" => "middle"}, + # "fill" => %{"signal" => "'black'"}, + # "font" => %{"value" => "Calibri"}, + # "fontSize" => %{ + # "signal" => "parent.split ? splitsFontSize : leavesFontSize" + # }, + "limit" => %{"signal" => "scaledNodeWidth-scaledLimit"}, + "text" => %{ + "signal" => "'leaf = ' + format(parent.leaf, '.2f')" + }, + "x" => %{"signal" => "(scaledNodeWidth / 2)"}, + "y" => %{"signal" => "scaledNodeHeight / 2"} + }, + format_mark(opts[:leaves]) + ) + }, + "interactive" => false, + "name" => "title", + "type" => "text" + } + ], "type" => "group" } ], @@ -697,7 +792,7 @@ defmodule EXGBoost.Plotting do "name" => "node", "on" => [ %{ - "events" => %{"markname" => "node", "type" => "click"}, + "events" => %{"markname" => "splitNode", "type" => "click"}, "update" => "datum.nodeid" } ], @@ -707,7 +802,10 @@ defmodule EXGBoost.Plotting do "name" => "nodeHighlight", "on" => [ %{ - "events" => %{"markname" => "node", "type" => "mouseover"}, + "events" => [ + %{"markname" => "splitNode", "type" => "mouseover"}, + %{"markname" => "leafNode", "type" => "mouseover"} + ], "update" => "pluck(treeAncestors('treeCalcs', datum.nodeid), 'nodeid')" }, %{"events" => %{"type" => "mouseout"}, "update" => "[0]"} @@ -718,7 +816,7 @@ defmodule EXGBoost.Plotting do "name" => "isExpanded", "on" => [ %{ - "events" => %{"markname" => "node", "type" => "click"}, + "events" => %{"markname" => "splitNode", "type" => "click"}, "update" => "datum.children > 0 && indata('treeClickStorePerm', 'nodeid', datum.childrenIds[0]) ? true : false" } @@ -872,6 +970,20 @@ defmodule EXGBoost.Plotting do key |> to_string() |> snake_to_camel() end + defp format_mark(opts) do + opts + |> opts_to_vl_props() + |> Enum.into(%{}, fn {key, value} -> + case key do + "fontSize" -> + {"fontSize", %{"signal" => "(#{value}/ span(xdom))*width"}} + + other -> + {other, %{"value" => value}} + end + end) + end + defp snake_to_camel(string) do [part | parts] = String.split(string, "_") Enum.join([String.downcase(part, :ascii) | Enum.map(parts, &capitalize/1)]) From 0356a2b54681f0b12973f635fbd94257b49e3309 Mon Sep 17 00:00:00 2001 From: acalejos Date: Thu, 14 Dec 2023 23:23:33 -0500 Subject: [PATCH 05/21] Finish plotting spec and opts --- lib/exgboost/plotting.ex | 422 ++++++++++++++++++++++++--------------- 1 file changed, 261 insertions(+), 161 deletions(-) diff --git a/lib/exgboost/plotting.ex b/lib/exgboost/plotting.ex index df45ac6..5b5b67a 100644 --- a/lib/exgboost/plotting.ex +++ b/lib/exgboost/plotting.ex @@ -81,7 +81,78 @@ defmodule EXGBoost.Plotting do |> Jason.decode!() |> ExJsonSchema.Schema.resolve() - @mark_params [] + @mark_text_doc "Accepts a keyword list of Vega `text` Mark properties. Reference [here](https://vega.github.io/vega/docs/marks/text/) for more details. Accepts either a string (expected to be valid Vega property names) or Elixir-styled atom. Note that keys are snake-cased instead of camel-case (e.g. Vega `fontSize` becomes `font_size`)" + @mark_rect_doc "Accepts a keyword list of Vega `rect` Mark properties. Reference [here](https://vega.github.io/vega/docs/marks/rect/) for more details. Accepts either a string (expected to be valid Vega property names) or Elixir-styled atom. Note that keys are snake-cased instead of camel-case (e.g. Vega `fontSize` becomes `font_size`)" + @mark_path_doc "Accepts a keyword list of Vega `path` Mark properties. Reference [here](https://vega.github.io/vega/docs/marks/path/) for more details. Accepts either a string (expected to be valid Vega property names) or Elixir-styled atom. Note that keys are snake-cased instead of camel-case (e.g. Vega `fontSize` becomes `font_size`)" + + @mark_opts [ + type: :keyword_list, + keys: [ + *: [type: {:or, [:string, :atom, :integer, :boolean, nil]}] + ] + ] + + @default_leaf_text [ + align: :center, + baseline: :middle, + font_size: 13, + font: "Calibri", + fill: :black + ] + + @default_leaf_rect [ + corner_radius: 2, + fill: :teal, + opacity: 1 + ] + + @default_split_text [ + align: :center, + baseline: :middle, + font_size: 13, + font: "Calibri", + fill: :black + ] + + @default_split_rect [ + corner_radius: 2, + fill: :teal, + opacity: 1 + ] + + @default_split_children [ + align: :right, + baseline: :middle, + fill: :black, + font: "Calibri", + font_size: 13 + ] + + @default_yes_path [ + stroke: :red + ] + + @default_yes_text [ + align: :center, + baseline: :middle, + font_size: 13, + font: "Calibri", + fill: :black, + text: "yes" + ] + + @default_no_path [ + stroke: :black + ] + + @default_no_text [ + align: :center, + baseline: :middle, + font_size: 13, + font: "Calibri", + fill: :black, + text: "yes" + ] @plotting_params [ autosize: [ @@ -126,12 +197,51 @@ defmodule EXGBoost.Plotting do doc: "Specifies characteristics of leaf nodes", type: :keyword_list, keys: [ - text: [], - rect: [] + text: + @mark_opts ++ + [ + doc: @mark_text_doc, + default: @default_leaf_text + ], + rect: + @mark_opts ++ + [ + doc: @mark_rect_doc, + default: @default_leaf_rect + ] + ], + default: [ + text: @default_leaf_text, + rect: @default_leaf_rect + ] + ], + splits: [ + doc: "Specifies characteristics of split nodes", + type: :keyword_list, + keys: [ + text: + @mark_opts ++ + [ + doc: @mark_text_doc, + default: @default_split_text + ], + rect: + @mark_opts ++ + [ + doc: @mark_rect_doc, + default: @default_split_rect + ], + children: + @mark_opts ++ + [ + doc: @mark_text_doc, + default: @default_split_children + ] ], default: [ - text: [], - rect: [] + text: @default_split_text, + rect: @default_split_rect, + children: @default_split_children ] ], node_width: [ @@ -152,28 +262,48 @@ defmodule EXGBoost.Plotting do ], default: [nodes: 10, levels: 100] ], - splits: [ - doc: "Specifies characteristics of inner nodes", + yes: [ + doc: "Specifies characteristics of links between nodes where the split condition is true", type: :keyword_list, keys: [ - fill: [type: :string, default: "#ccc"], - stroke: [type: :string, default: "black"], - stroke_width: [type: :pos_integer, default: 1], - font_size: [type: :pos_integer, default: 13] + path: + @mark_opts ++ + [ + doc: @mark_path_doc, + default: @default_yes_path + ], + text: + @mark_opts ++ + [ + doc: @mark_text_doc, + default: @default_yes_text + ] ], default: [ - fill: "#ccc", - stroke: "black", - stroke_width: 1, - font_size: 13 + path: @default_yes_path, + text: @default_yes_text ] ], - links: [ - doc: "Specifies characteristics of links between nodes", + no: [ + doc: "Specifies characteristics of links between nodes where the split condition is false", type: :keyword_list, keys: [ - yes: [], - no: [] + path: + @mark_opts ++ + [ + doc: @mark_path_doc, + default: @default_no_path + ], + text: + @mark_opts ++ + [ + doc: @mark_text_doc, + default: @default_no_text + ] + ], + default: [ + path: @default_no_path, + text: @default_no_text ] ], validate: [ @@ -465,33 +595,23 @@ defmodule EXGBoost.Plotting do %{"as" => "isLeaf", "expr" => "datum.leaf == null", "type" => "formula"} ] }, - %{ - "name" => "visibleNodes", - "source" => "fullTreeLayout", - "transform" => [ - %{ - "expr" => "indata('treeClickStorePerm', 'nodeid', datum.nodeid)", - "type" => "filter" - } - ] - }, %{ "name" => "splitNodes", - "source" => "visibleNodes", + "source" => "fullTreeLayout", "transform" => [ %{ "type" => "filter", - "expr" => "datum.isLeaf" + "expr" => "indata('treeClickStorePerm', 'nodeid', datum.nodeid) && datum.isLeaf" } ] }, %{ "name" => "leafNodes", - "source" => "visibleNodes", + "source" => "fullTreeLayout", "transform" => [ %{ "type" => "filter", - "expr" => "!datum.isLeaf" + "expr" => "indata('treeClickStorePerm', 'nodeid', datum.nodeid) && !datum.isLeaf" } ] }, @@ -558,32 +678,34 @@ defmodule EXGBoost.Plotting do "$schema" => "https://vega.github.io/schema/vega/v5.json", "marks" => [ %{ - "encode" => %{ - "update" => %{ - "path" => %{"field" => "path"}, - "stroke" => %{ - "signal" => "datum.source.yes === datum.target.nodeid ? 'red' : 'black'" + "encode" => + Map.merge( + %{ + "update" => %{ + "path" => %{"field" => "path"}, + "strokeWidth" => %{ + "signal" => "indexof(nodeHighlight, datum.target.nodeid)> -1? 2:1" + } + } }, - "strokeWidth" => %{ - "signal" => "indexof(nodeHighlight, datum.target.nodeid)> -1? 2:1" - } - } - }, + format_mark(opts[:yes][:path]) + ), "from" => %{"data" => "yesPaths"}, "interactive" => false, "type" => "path" }, %{ "encode" => %{ - "update" => %{ - "path" => %{"field" => "path"}, - "stroke" => %{ - "signal" => "datum.source.yes === datum.target.nodeid ? 'red' : 'black'" - }, - "strokeWidth" => %{ - "signal" => "indexof(nodeHighlight, datum.target.nodeid)> -1? 2:1" - } - } + "update" => + Map.merge( + %{ + "path" => %{"field" => "path"}, + "strokeWidth" => %{ + "signal" => "indexof(nodeHighlight, datum.target.nodeid)> -1? 2:1" + } + }, + format_mark(opts[:no][:path]) + ) }, "from" => %{"data" => "noPaths"}, "interactive" => false, @@ -591,40 +713,40 @@ defmodule EXGBoost.Plotting do }, %{ "encode" => %{ - "update" => %{ - "align" => %{"value" => "center"}, - "baseline" => %{"value" => "middle"}, - "fontSize" => %{"signal" => "linksFontSize"}, - "text" => %{"signal" => "datum.source.yes === datum.target.nodeid ? 'yes' : 'no'"}, - "x" => %{ - "signal" => - "(scale('xscale', datum.source.x+(nodeWidth/3)) + scale('xscale', datum.target.x)) / 2" - }, - "y" => %{ - "signal" => - "(scale('yscale', datum.source.y) + scale('yscale', datum.target.y)) / 2 - (scaledNodeHeight/2)" - } - } + "update" => + Map.merge( + %{ + "x" => %{ + "signal" => + "(scale('xscale', datum.source.x+(nodeWidth/3)) + scale('xscale', datum.target.x)) / 2" + }, + "y" => %{ + "signal" => + "(scale('yscale', datum.source.y) + scale('yscale', datum.target.y)) / 2 - (scaledNodeHeight/2)" + } + }, + format_mark(opts[:yes][:text]) + ) }, "from" => %{"data" => "yesPaths"}, "type" => "text" }, %{ "encode" => %{ - "update" => %{ - "align" => %{"value" => "center"}, - "baseline" => %{"value" => "middle"}, - "fontSize" => %{"signal" => "linksFontSize"}, - "text" => %{"signal" => "datum.source.yes === datum.target.nodeid ? 'yes' : 'no'"}, - "x" => %{ - "signal" => - "(scale('xscale', datum.source.x+(nodeWidth/3)) + scale('xscale', datum.target.x)) / 2" - }, - "y" => %{ - "signal" => - "(scale('yscale', datum.source.y) + scale('yscale', datum.target.y)) / 2 - (scaledNodeHeight/2)" - } - } + "update" => + Map.merge( + %{ + "x" => %{ + "signal" => + "(scale('xscale', datum.source.x+(nodeWidth/3)) + scale('xscale', datum.target.x)) / 2" + }, + "y" => %{ + "signal" => + "(scale('yscale', datum.source.y) + scale('yscale', datum.target.y)) / 2 - (scaledNodeHeight/2)" + } + }, + format_mark(opts[:no][:text]) + ) }, "from" => %{"data" => "noPaths"}, "type" => "text" @@ -632,37 +754,37 @@ defmodule EXGBoost.Plotting do %{ "clip" => false, "encode" => %{ - "update" => %{ - "cornerRadius" => %{"value" => 2}, - "cursor" => %{"signal" => "datum.children > 0 ? 'pointer' : '' "}, - "fill" => %{"signal" => "datum.isLeaf ? leafColor : splitColor"}, - "height" => %{"signal" => "scaledNodeHeight"}, - "opacity" => %{"value" => 1}, - "tooltip" => %{"signal" => ""}, - "width" => %{"signal" => "scaledNodeWidth"}, - "x" => %{"signal" => "datum.xscaled - (scaledNodeWidth / 2)"}, - "yc" => %{"signal" => "scale('yscale',datum.y) - (scaledNodeHeight/2)"} - } + "update" => + Map.merge( + %{ + "cursor" => %{"signal" => "datum.children > 0 ? 'pointer' : '' "}, + "height" => %{"signal" => "scaledNodeHeight"}, + "tooltip" => %{"signal" => ""}, + "width" => %{"signal" => "scaledNodeWidth"}, + "x" => %{"signal" => "datum.xscaled - (scaledNodeWidth / 2)"}, + "yc" => %{"signal" => "scale('yscale',datum.y) - (scaledNodeHeight/2)"} + }, + format_mark(opts[:splits][:rect]) + ) }, "from" => %{"data" => "splitNodes"}, "name" => "splitNode", "marks" => [ %{ "encode" => %{ - "update" => %{ - "align" => %{"value" => "center"}, - "baseline" => %{"value" => "middle"}, - "fill" => %{"signal" => "'black'"}, - "font" => %{"value" => "Calibri"}, - "fontSize" => %{"signal" => "parent.split ? splitsFontSize : leavesFontSize"}, - "limit" => %{"signal" => "scaledNodeWidth-scaledLimit"}, - "text" => %{ - "signal" => - "parent.split ? parent.split + ' <= ' + format(parent.split_condition, '.2f') : 'leaf = ' + format(parent.leaf, '.2f')" - }, - "x" => %{"signal" => "(scaledNodeWidth / 2)"}, - "y" => %{"signal" => "scaledNodeHeight / 2"} - } + "update" => + Map.merge( + %{ + "limit" => %{"signal" => "scaledNodeWidth-scaledLimit"}, + "text" => %{ + "signal" => + "parent.split + ' <= ' + format(parent.split_condition, '.2f')" + }, + "x" => %{"signal" => "(scaledNodeWidth / 2)"}, + "y" => %{"signal" => "scaledNodeHeight / 2"} + }, + format_mark(opts[:splits][:text]) + ) }, "interactive" => false, "name" => "title", @@ -670,16 +792,15 @@ defmodule EXGBoost.Plotting do }, %{ "encode" => %{ - "update" => %{ - "align" => %{"value" => "right"}, - "baseline" => %{"value" => "middle"}, - "fill" => %{"signal" => "childrenColor"}, - "font" => %{"value" => "Calibri"}, - "fontSize" => %{"signal" => "splitsFontSize"}, - "text" => %{"signal" => "parent.children > 0 ? parent.children :''"}, - "x" => %{"signal" => "item.mark.group.width - (9/ span(xdom))*width"}, - "y" => %{"signal" => "item.mark.group.height/2"} - } + "update" => + Map.merge( + %{ + "text" => %{"signal" => "parent.children"}, + "x" => %{"signal" => "item.mark.group.width - (9/ span(xdom))*width"}, + "y" => %{"signal" => "item.mark.group.height/2"} + }, + format_mark(opts[:splits][:children]) + ) }, "interactive" => false, "type" => "text" @@ -690,17 +811,18 @@ defmodule EXGBoost.Plotting do %{ "clip" => false, "encode" => %{ - "update" => %{ - "cornerRadius" => %{"value" => 2}, - "cursor" => %{"signal" => "datum.children > 0 ? 'pointer' : '' "}, - "fill" => %{"signal" => "datum.isLeaf ? leafColor : splitColor"}, - "height" => %{"signal" => "scaledNodeHeight"}, - "opacity" => %{"value" => 1}, - "tooltip" => %{"signal" => ""}, - "width" => %{"signal" => "scaledNodeWidth"}, - "x" => %{"signal" => "datum.xscaled - (scaledNodeWidth / 2)"}, - "yc" => %{"signal" => "scale('yscale',datum.y) - (scaledNodeHeight/2)"} - } + "update" => + Map.merge( + %{ + "cursor" => %{"signal" => "datum.children > 0 ? 'pointer' : '' "}, + "height" => %{"signal" => "scaledNodeHeight"}, + "tooltip" => %{"signal" => ""}, + "width" => %{"signal" => "scaledNodeWidth"}, + "x" => %{"signal" => "datum.xscaled - (scaledNodeWidth / 2)"}, + "yc" => %{"signal" => "scale('yscale',datum.y) - (scaledNodeHeight/2)"} + }, + format_mark(opts[:leaves][:rect]) + ) }, "from" => %{"data" => "leafNodes"}, "name" => "leafNode", @@ -710,13 +832,6 @@ defmodule EXGBoost.Plotting do "update" => Map.merge( %{ - # "align" => %{"value" => "center"}, - # "baseline" => %{"value" => "middle"}, - # "fill" => %{"signal" => "'black'"}, - # "font" => %{"value" => "Calibri"}, - # "fontSize" => %{ - # "signal" => "parent.split ? splitsFontSize : leavesFontSize" - # }, "limit" => %{"signal" => "scaledNodeWidth-scaledLimit"}, "text" => %{ "signal" => "'leaf = ' + format(parent.leaf, '.2f')" @@ -724,7 +839,7 @@ defmodule EXGBoost.Plotting do "x" => %{"signal" => "(scaledNodeWidth / 2)"}, "y" => %{"signal" => "scaledNodeHeight / 2"} }, - format_mark(opts[:leaves]) + format_mark(opts[:leaves][:text]) ) }, "interactive" => false, @@ -771,23 +886,6 @@ defmodule EXGBoost.Plotting do } ) ), - %{"name" => "childrenColor", "value" => "black"}, - %{"name" => "leafColor", "value" => opts[:leaves][:fill]}, - %{"name" => "splitColor", "value" => opts[:splits][:fill]}, - %{"name" => "spaceBetweenLevels", "value" => opts[:space_between][:levels]}, - %{"name" => "spaceBetweenNodes", "value" => opts[:space_between][:nodes]}, - %{"name" => "nodeWidth", "value" => opts[:node_width]}, - %{"name" => "nodeHeight", "value" => opts[:node_height]}, - %{ - "name" => "startingDepth", - "on" => [%{"events" => %{"throttle" => 0, "type" => "timer"}, "update" => "-1"}], - "value" => - opts[:depth] || - to_tabular(booster) - |> Enum.map(&Map.get(&1, "depth")) - |> Enum.filter(& &1) - |> Enum.max() - }, %{ "name" => "node", "on" => [ @@ -915,17 +1013,19 @@ defmodule EXGBoost.Plotting do %{"name" => "scaledNodeWidth", "update" => "(nodeWidth/ span(xdom))*width"}, %{"name" => "scaledNodeHeight", "update" => "abs(nodeHeight/ span(ydom))*height"}, %{"name" => "scaledLimit", "update" => "(20/ span(xdom))*width"}, + %{"name" => "spaceBetweenLevels", "value" => opts[:space_between][:levels]}, + %{"name" => "spaceBetweenNodes", "value" => opts[:space_between][:nodes]}, + %{"name" => "nodeWidth", "value" => opts[:node_width]}, + %{"name" => "nodeHeight", "value" => opts[:node_height]}, %{ - "name" => "leavesFontSize", - "update" => "(#{opts[:leaves][:font_size]}/ span(xdom))*width" - }, - %{ - "name" => "splitsFontSize", - "update" => "(#{opts[:splits][:font_size]}/ span(xdom))*width" - }, - %{ - "name" => "linksFontSize", - "update" => "(#{opts[:links][:font_size]}/ span(xdom))*width" + "name" => "startingDepth", + "on" => [%{"events" => %{"throttle" => 0, "type" => "timer"}, "update" => "-1"}], + "value" => + opts[:depth] || + to_tabular(booster) + |> Enum.map(&Map.get(&1, "depth")) + |> Enum.filter(& &1) + |> Enum.max() } ] }) From cbdcdda572d7192109c733ef4d973f278a9691be Mon Sep 17 00:00:00 2001 From: acalejos Date: Tue, 26 Dec 2023 13:09:18 -0500 Subject: [PATCH 06/21] Update plotting --- lib/exgboost.ex | 2 +- lib/exgboost/booster.ex | 12 ++++ lib/exgboost/plotting.ex | 151 +++++++++++++++++++++++++++++++-------- lib/exgboost/training.ex | 4 +- 4 files changed, 138 insertions(+), 31 deletions(-) diff --git a/lib/exgboost.ex b/lib/exgboost.ex index 9b3bb07..7755b61 100644 --- a/lib/exgboost.ex +++ b/lib/exgboost.ex @@ -212,7 +212,7 @@ defmodule EXGBoost do def train(x, y, opts \\ []) do x = Nx.concatenate(x) y = Nx.concatenate(y) - {dmat_opts, opts} = Keyword.split(opts, Internal.dmatrix_feature_opts()) + dmat_opts = Keyword.take(opts, Internal.dmatrix_feature_opts()) dmat = DMatrix.from_tensor(x, y, Keyword.put_new(dmat_opts, :format, :dense)) Training.train(dmat, opts) end diff --git a/lib/exgboost/booster.ex b/lib/exgboost/booster.ex index f069cdc..f60b8ed 100644 --- a/lib/exgboost/booster.ex +++ b/lib/exgboost/booster.ex @@ -163,9 +163,15 @@ defmodule EXGBoost.Booster do def booster(dmats, opts \\ []) def booster(dmats, opts) when is_list(dmats) do + {str_opts, opts} = Keyword.split(opts, Internal.dmatrix_str_feature_opts()) opts = EXGBoost.Parameters.validate!(opts) refs = Enum.map(dmats, & &1.ref) booster_ref = EXGBoost.NIF.booster_create(refs) |> Internal.unwrap!() + + Enum.each(str_opts, fn {key, value} -> + EXGBoost.NIF.booster_set_str_feature_info(booster_ref, Atom.to_string(key), value) + end) + set_params(%__MODULE__{ref: booster_ref}, opts) end @@ -174,9 +180,15 @@ defmodule EXGBoost.Booster do end def booster(%__MODULE__{} = bst, opts) do + {str_opts, opts} = Keyword.split(opts, Internal.dmatrix_str_feature_opts()) opts = EXGBoost.Parameters.validate!(opts) boostr_bytes = EXGBoost.NIF.booster_serialize_to_buffer(bst.ref) |> Internal.unwrap!() booster_ref = EXGBoost.NIF.booster_deserialize_from_buffer(boostr_bytes) |> Internal.unwrap!() + + Enum.each(str_opts, fn {key, value} -> + EXGBoost.NIF.booster_set_str_feature_info(booster_ref, Atom.to_string(key), value) + end) + set_params(%__MODULE__{ref: booster_ref}, opts) end diff --git a/lib/exgboost/plotting.ex b/lib/exgboost/plotting.ex index 5b5b67a..c684612 100644 --- a/lib/exgboost/plotting.ex +++ b/lib/exgboost/plotting.ex @@ -1,22 +1,3 @@ -# defmodule EXGBoost.Plotting.Schema do -# @moduledoc false -# Code.ensure_compiled!(ExJsonSchema) - -# defmacro __before_compile__(_env) do -# HTTPoison.start() - -# schema = -# HTTPoison.get!("https://vega.github.io/schema/vega/v5.json").body -# |> Jason.decode!() -# |> ExJsonSchema.Schema.resolve() -# |> Macro.escape() - -# quote do -# def get_schema(), do: unquote(schema) -# end -# end -# end - defmodule EXGBoost.Plotting do @moduledoc """ Functions for plotting EXGBoost `Booster` models using [Vega](https://vega.github.io/vega/) @@ -151,10 +132,15 @@ defmodule EXGBoost.Plotting do font_size: 13, font: "Calibri", fill: :black, - text: "yes" + text: "no" ] @plotting_params [ + rankdir: [ + doc: "Determines the direction of the graph.", + type: {:in, [:tb, :lr]}, + default: :tb + ], autosize: [ doc: "Determines if the visualization should automatically resize when the window size changes", @@ -327,8 +313,12 @@ defmodule EXGBoost.Plotting do @plotting_schema NimbleOptions.new!(@plotting_params) + @defaults NimbleOptions.validate!([], @plotting_schema) + def get_schema(), do: @schema + def get_defaults(), do: @defaults + defp validate_spec(spec) do case ExJsonSchema.Validator.validate(get_schema(), spec) do :ok -> @@ -418,7 +408,9 @@ defmodule EXGBoost.Plotting do This function is useful if you want to create a custom Vega specification, but still want to use the data transformation provided by the default specification. """ - def get_data_spec(booster) do + def get_data_spec(booster, opts \\ []) do + opts = NimbleOptions.validate!(opts, @plotting_schema) + %{ "$schema" => "https://vega.github.io/schema/vega/v5.json", "data" => [ @@ -430,7 +422,14 @@ defmodule EXGBoost.Plotting do %{"expr" => "datum.tree_id === selectedTree", "type" => "filter"}, %{"key" => "nodeid", "parentKey" => "parentid", "type" => "stratify"}, %{ - "as" => ["y", "x", "depth", "children"], + "as" => + case opts[:rankdir] do + :tb -> + ["x", "y", "depth", "children"] + + :lr -> + ["y", "x", "depth", "children"] + end, "method" => "tidy", "separation" => %{"signal" => "false"}, "type" => "tree" @@ -547,7 +546,14 @@ defmodule EXGBoost.Plotting do }, %{"key" => "nodeid", "parentKey" => "parentid", "type" => "stratify"}, %{ - "as" => ["x", "y", "depth", "children"], + "as" => + case opts[:rankdir] do + :tb -> + ["x", "y", "depth", "children"] + + :lr -> + ["y", "x", "depth", "children"] + end, "method" => "tidy", "nodeSize" => [ %{"signal" => "nodeWidth + spaceBetweenNodes"}, @@ -621,12 +627,55 @@ defmodule EXGBoost.Plotting do "transform" => [ %{"type" => "treelinks"}, %{ - "orient" => "vertical", + "orient" => + case opts[:rankdir] do + vertical when vertical in [:tb, :bt] -> + "vertical" + + horizontal when horizontal in [:lr, :rl] -> + "horizontal" + end, "shape" => "line", - "sourceX" => %{"expr" => "scale('xscale', datum.source.x)"}, - "sourceY" => %{"expr" => "scale('yscale', datum.source.y)"}, - "targetX" => %{"expr" => "scale('xscale', datum.target.x)"}, - "targetY" => %{"expr" => "scale('yscale', datum.target.y) - scaledNodeHeight"}, + "sourceX" => %{ + "expr" => + case opts[:rankdir] do + :tb -> + "scale('xscale', datum.source.x)" + + :lr -> + "scale('xscale', datum.source.x) + scaledNodeWidth/2" + end + }, + "sourceY" => %{ + "expr" => + case opts[:rankdir] do + :tb -> + "scale('yscale', datum.source.y)" + + :lr -> + "scale('yscale', datum.source.y) - scaledNodeHeight/2" + end + }, + "targetX" => %{ + "expr" => + case opts[:rankdir] do + :tb -> + "scale('xscale', datum.target.x)" + + :lr -> + "scale('xscale', datum.target.x) - scaledNodeWidth/2" + end + }, + "targetY" => %{ + "expr" => + case opts[:rankdir] do + :tb -> + "scale('yscale', datum.target.y) - scaledNodeHeight" + + :lr -> + "scale('yscale', datum.target.y) - scaledNodeHeight/2" + end + }, "type" => "linkpath" }, %{ @@ -660,8 +709,52 @@ defmodule EXGBoost.Plotting do end def to_vega(booster, opts \\ []) do + spec = get_data_spec(booster, opts) opts = NimbleOptions.validate!(opts, @plotting_schema) - spec = get_data_spec(booster) + defaults = get_defaults() + + # Try to account for non-default node height / width to adjust spacing + # between nodes and levels as a quality of life improvement + # If they provide their own spacing, don't adjust it + opts = + cond do + opts[:rankdir] == :lr and + opts[:space_between][:levels] == defaults[:space_between][:levels] -> + put_in( + opts[:space_between][:levels], + defaults[:space_between][:levels] + (opts[:node_width] - defaults[:node_width]) + ) + + opts[:rankdir] == :tb and + opts[:space_between][:levels] == defaults[:space_between][:levels] -> + put_in( + opts[:space_between][:levels], + defaults[:space_between][:levels] + (opts[:node_height] - defaults[:node_height]) + ) + + true -> + opts + end + + opts = + cond do + opts[:rankdir] == :lr and + opts[:space_between][:nodes] == defaults[:space_between][:nodes] -> + put_in( + opts[:space_between][:nodes], + defaults[:space_between][:nodes] + (opts[:node_height] - defaults[:node_height]) + ) + + opts[:rankdir] == :tb and + opts[:space_between][:nodes] == defaults[:space_between][:nodes] -> + put_in( + opts[:space_between][:nodes], + defaults[:space_between][:nodes] + (opts[:node_width] - defaults[:node_width]) + ) + + true -> + opts + end [%{"$ref" => root} | [%{"properties" => properties}]] = EXGBoost.Plotting.get_schema().schema |> Map.get("allOf") diff --git a/lib/exgboost/training.ex b/lib/exgboost/training.ex index 41f5a76..0fdea1b 100644 --- a/lib/exgboost/training.ex +++ b/lib/exgboost/training.ex @@ -6,6 +6,8 @@ defmodule EXGBoost.Training do @spec train(DMatrix.t(), Keyword.t()) :: Booster.t() def train(%DMatrix{} = dmat, opts \\ []) do + dmat_opts = Keyword.take(opts, EXGBoost.Internal.dmatrix_feature_opts()) + {opts, booster_params} = Keyword.split(opts, [ :obj, @@ -40,7 +42,7 @@ defmodule EXGBoost.Training do evals_dmats = Enum.map(evals, fn {%Nx.Tensor{} = x, %Nx.Tensor{} = y, name} -> - {DMatrix.from_tensor(x, y, format: :dense), name} + {DMatrix.from_tensor(x, y, Keyword.put_new(dmat_opts, :format, :dense)), name} end) bst = From 0edbe26d9224a45fbcc5d1f92331d24b45adc43f Mon Sep 17 00:00:00 2001 From: acalejos Date: Tue, 26 Dec 2023 14:01:29 -0500 Subject: [PATCH 07/21] Add rankdir option --- lib/exgboost/plotting.ex | 38 +++++++++++++++++++++++--------------- 1 file changed, 23 insertions(+), 15 deletions(-) diff --git a/lib/exgboost/plotting.ex b/lib/exgboost/plotting.ex index c684612..0081ce2 100644 --- a/lib/exgboost/plotting.ex +++ b/lib/exgboost/plotting.ex @@ -138,7 +138,7 @@ defmodule EXGBoost.Plotting do @plotting_params [ rankdir: [ doc: "Determines the direction of the graph.", - type: {:in, [:tb, :lr]}, + type: {:in, [:tb, :lr, :bt, :rl]}, default: :tb ], autosize: [ @@ -424,10 +424,10 @@ defmodule EXGBoost.Plotting do %{ "as" => case opts[:rankdir] do - :tb -> + vertical when vertical in [:tb, :bt] -> ["x", "y", "depth", "children"] - :lr -> + horizontal when horizontal in [:lr, :rl] -> ["y", "x", "depth", "children"] end, "method" => "tidy", @@ -548,10 +548,10 @@ defmodule EXGBoost.Plotting do %{ "as" => case opts[:rankdir] do - :tb -> + vertical when vertical in [:tb, :bt] -> ["x", "y", "depth", "children"] - :lr -> + horizontal when horizontal in [:lr, :rl] -> ["y", "x", "depth", "children"] end, "method" => "tidy", @@ -562,8 +562,16 @@ defmodule EXGBoost.Plotting do "separation" => %{"signal" => "false"}, "type" => "tree" }, - %{"as" => "y", "expr" => "datum.y+(height/10)", "type" => "formula"}, - %{"as" => "x", "expr" => "datum.x+(width/2)", "type" => "formula"}, + %{ + "as" => "y", + "expr" => "#{if opts[:rankdir] == :bt, do: -1, else: 1}*(datum.y+(height/10))", + "type" => "formula" + }, + %{ + "as" => "x", + "expr" => "#{if opts[:rankdir] == :rl, do: -1, else: 1}*(datum.x+(width/2))", + "type" => "formula" + }, %{"as" => "xscaled", "expr" => "scale('xscale',datum.x)", "type" => "formula"}, %{"as" => "parent", "expr" => "datum.parentid", "type" => "formula"} ] @@ -639,40 +647,40 @@ defmodule EXGBoost.Plotting do "sourceX" => %{ "expr" => case opts[:rankdir] do - :tb -> + vertical when vertical in [:tb, :bt] -> "scale('xscale', datum.source.x)" - :lr -> + horizontal when horizontal in [:lr, :rl] -> "scale('xscale', datum.source.x) + scaledNodeWidth/2" end }, "sourceY" => %{ "expr" => case opts[:rankdir] do - :tb -> + vertical when vertical in [:tb, :bt] -> "scale('yscale', datum.source.y)" - :lr -> + horizontal when horizontal in [:lr, :rl] -> "scale('yscale', datum.source.y) - scaledNodeHeight/2" end }, "targetX" => %{ "expr" => case opts[:rankdir] do - :tb -> + vertical when vertical in [:tb, :bt] -> "scale('xscale', datum.target.x)" - :lr -> + horizontal when horizontal in [:lr, :rl] -> "scale('xscale', datum.target.x) - scaledNodeWidth/2" end }, "targetY" => %{ "expr" => case opts[:rankdir] do - :tb -> + vertical when vertical in [:tb, :bt] -> "scale('yscale', datum.target.y) - scaledNodeHeight" - :lr -> + horizontal when horizontal in [:lr, :rl] -> "scale('yscale', datum.target.y) - scaledNodeHeight/2" end }, From 2962a417be8d57e60fa74805d401ccd349ba4b96 Mon Sep 17 00:00:00 2001 From: acalejos Date: Thu, 18 Jan 2024 22:06:57 -0500 Subject: [PATCH 08/21] Add support for styles --- Makefile | 1 + lib/exgboost/plotting.ex | 106 ++++++++++-- lib/exgboost/plotting/styles.ex | 293 ++++++++++++++++++++++++++++++++ mix.exs | 2 +- 4 files changed, 387 insertions(+), 15 deletions(-) create mode 100644 lib/exgboost/plotting/styles.ex diff --git a/Makefile b/Makefile index 4520e18..41c14c3 100644 --- a/Makefile +++ b/Makefile @@ -63,6 +63,7 @@ $(XGBOOST_LIB_DIR_FLAG): git fetch --depth 1 --recurse-submodules origin $(XGBOOST_GIT_REV) && \ git checkout FETCH_HEAD && \ git submodule update --init --recursive && \ + sed 's|learner_parameters\["generic_param"\] = ToJson(ctx_);|&\nlearner_parameters\["default_metric"\] = String(obj_->DefaultEvalMetric());|' src/learner.cc > src/learner.cc.tmp && mv src/learner.cc.tmp src/learner.cc && \ cmake -DCMAKE_INSTALL_PREFIX=$(XGBOOST_LIB_DIR) -B build . $(CMAKE_FLAGS) && \ make -C build -j1 install touch $(XGBOOST_LIB_DIR_FLAG) diff --git a/lib/exgboost/plotting.ex b/lib/exgboost/plotting.ex index 0081ce2..87f6f98 100644 --- a/lib/exgboost/plotting.ex +++ b/lib/exgboost/plotting.ex @@ -55,7 +55,6 @@ defmodule EXGBoost.Plotting do """ - HTTPoison.start() @schema HTTPoison.get!("https://vega.github.io/schema/vega/v5.json").body @@ -69,7 +68,7 @@ defmodule EXGBoost.Plotting do @mark_opts [ type: :keyword_list, keys: [ - *: [type: {:or, [:string, :atom, :integer, :boolean, nil]}] + *: [type: {:or, [:string, :atom, :integer, :boolean, nil, {:list, :any}]}] ] ] @@ -136,6 +135,27 @@ defmodule EXGBoost.Plotting do ] @plotting_params [ + style: [ + doc: "The style to use for the visualization.", + default: :solarized_light, + type: + {:in, + [ + :solarized_light, + :solarized_dark, + :playful_light, + :playful_dark, + :dark, + :light, + :nord, + :dracula, + :gruvbox, + :high_contrast, + :monokai, + :material, + :one_dark + ]} + ], rankdir: [ doc: "Determines the direction of the graph.", type: {:in, [:tb, :lr, :bt, :rl]}, @@ -241,10 +261,19 @@ defmodule EXGBoost.Plotting do default: 45 ], space_between: [ + doc: "The space between the rectangular marks in pixels.", type: :keyword_list, keys: [ - nodes: [type: :pos_integer, default: 10], - levels: [type: :pos_integer, default: 100] + nodes: [ + doc: "Space between marks within the same depth of the tree.", + type: :pos_integer, + default: 10 + ], + levels: [ + doc: "Space between each rank / depth of the tree.", + type: :pos_integer, + default: 100 + ] ], default: [nodes: 10, levels: 100] ], @@ -354,6 +383,16 @@ defmodule EXGBoost.Plotting do end end + def deep_merge_kw(a, b) do + Keyword.merge(a, b, fn + _key, val_a, val_b when is_list(val_a) and is_list(val_b) -> + deep_merge_kw(val_a, val_b) + + _key, _val_a, val_b -> + val_b + end) + end + defp new_node(params = %{}) do Map.merge( %{ @@ -386,7 +425,7 @@ defmodule EXGBoost.Plotting do - yes: The node id of the left child - no: The node id of the right child - missing: The node id of the missing child - - depth: The depth of the node + - depth: The depth of the node (root node is depth 1) - leaf: The leaf value if it is a leaf node """ def to_tabular(booster) do @@ -407,10 +446,28 @@ defmodule EXGBoost.Plotting do This function is useful if you want to create a custom Vega specification, but still want to use the data transformation provided by the default specification. + + ## Options + + * `:rankdir` - Determines the direction of the graph. Accepts one of `:tb`, `:lr`, `:bt`, or `:rl`. Defaults to `:tb` + """ def get_data_spec(booster, opts \\ []) do opts = NimbleOptions.validate!(opts, @plotting_schema) + opts = + unless opts[:style] in [nil, false] do + style = + NimbleOptions.validate!( + apply(EXGBoost.Plotting.Styles, opts[:style], []), + @plotting_schema + ) + + deep_merge_kw(opts, style) + else + opts + end + %{ "$schema" => "https://vega.github.io/schema/vega/v5.json", "data" => [ @@ -650,8 +707,11 @@ defmodule EXGBoost.Plotting do vertical when vertical in [:tb, :bt] -> "scale('xscale', datum.source.x)" - horizontal when horizontal in [:lr, :rl] -> + :lr -> "scale('xscale', datum.source.x) + scaledNodeWidth/2" + + :rl -> + "scale('xscale', datum.source.x) - scaledNodeWidth/2" end }, "sourceY" => %{ @@ -670,8 +730,11 @@ defmodule EXGBoost.Plotting do vertical when vertical in [:tb, :bt] -> "scale('xscale', datum.target.x)" - horizontal when horizontal in [:lr, :rl] -> + :lr -> "scale('xscale', datum.target.x) - scaledNodeWidth/2" + + :rl -> + "scale('xscale', datum.target.x) + scaledNodeWidth/2" end }, "targetY" => %{ @@ -719,6 +782,20 @@ defmodule EXGBoost.Plotting do def to_vega(booster, opts \\ []) do spec = get_data_spec(booster, opts) opts = NimbleOptions.validate!(opts, @plotting_schema) + + opts = + unless opts[:style] in [nil, false] do + style = + NimbleOptions.validate!( + apply(EXGBoost.Plotting.Styles, opts[:style], []), + @plotting_schema + ) + + deep_merge_kw(opts, style) + else + opts + end + defaults = get_defaults() # Try to account for non-default node height / width to adjust spacing @@ -726,14 +803,14 @@ defmodule EXGBoost.Plotting do # If they provide their own spacing, don't adjust it opts = cond do - opts[:rankdir] == :lr and + opts[:rankdir] in [:lr, :rl] and opts[:space_between][:levels] == defaults[:space_between][:levels] -> put_in( opts[:space_between][:levels], defaults[:space_between][:levels] + (opts[:node_width] - defaults[:node_width]) ) - opts[:rankdir] == :tb and + opts[:rankdir] in [:tb, :bt] and opts[:space_between][:levels] == defaults[:space_between][:levels] -> put_in( opts[:space_between][:levels], @@ -746,14 +823,14 @@ defmodule EXGBoost.Plotting do opts = cond do - opts[:rankdir] == :lr and + opts[:rankdir] in [:lr, :rl] and opts[:space_between][:nodes] == defaults[:space_between][:nodes] -> put_in( opts[:space_between][:nodes], defaults[:space_between][:nodes] + (opts[:node_height] - defaults[:node_height]) ) - opts[:rankdir] == :tb and + opts[:rankdir] in [:tb, :bt] and opts[:space_between][:nodes] == defaults[:space_between][:nodes] -> put_in( opts[:space_between][:nodes], @@ -937,7 +1014,7 @@ defmodule EXGBoost.Plotting do "text" => %{ "signal" => "'leaf = ' + format(parent.leaf, '.2f')" }, - "x" => %{"signal" => "(scaledNodeWidth / 2)"}, + "x" => %{"signal" => "scaledNodeWidth / 2"}, "y" => %{"signal" => "scaledNodeHeight / 2"} }, format_mark(opts[:leaves][:text]) @@ -1131,8 +1208,9 @@ defmodule EXGBoost.Plotting do ] }) - spec = Map.merge(spec, tlk) - spec = (opts[:validate] && validate_spec(spec)) || spec + spec = Map.merge(spec, tlk) |> Map.delete("style") + File.write!("/Users/andres/Documents/exgboost/spec.json", Jason.encode!(spec)) + spec = if opts[:validate], do: validate_spec(spec), else: spec Jason.encode!(spec) |> VegaLite.from_json() end diff --git a/lib/exgboost/plotting/styles.ex b/lib/exgboost/plotting/styles.ex new file mode 100644 index 0000000..62d4807 --- /dev/null +++ b/lib/exgboost/plotting/styles.ex @@ -0,0 +1,293 @@ +defmodule EXGBoost.Plotting.Styles do + @moduledoc """ + A style is a keyword-map that adheres to the plotting schema + as defined in `EXGBoost.Plotting`. + """ + + def solarized_light(), + do: [ + # base3 + background: "#fdf6e3", + leaves: [ + # base01 + text: [fill: "#586e75", fontSize: 12, fontStyle: "normal", fontWeight: "normal"], + # base2, base1 + rect: [fill: "#eee8d5", stroke: "#93a1a1", strokeWidth: 1] + ], + splits: [ + # base01 + text: [fill: "#586e75", fontSize: 12, fontStyle: "normal", fontWeight: "bold"], + # base1, base00 + rect: [fill: "#93a1a1", stroke: "#657b83", strokeWidth: 1], + # base00 + children: [fill: "#657b83", stroke: "#657b83", strokeWidth: 1] + ], + yes: [ + # green + text: [fill: "#859900"], + # base00 + path: [stroke: "#657b83", strokeWidth: 1] + ], + no: [ + # red + text: [fill: "#dc322f"], + # base00 + path: [stroke: "#657b83", strokeWidth: 1] + ] + ] + + def solarized_dark(), + do: [ + # base03 + background: "#002b36", + leaves: [ + # base0 + text: [fill: "#839496", fontSize: 12, fontStyle: "normal", fontWeight: "normal"], + # base02, base01 + rect: [fill: "#073642", stroke: "#586e75", strokeWidth: 1] + ], + splits: [ + # base0 + text: [fill: "#839496", fontSize: 12, fontStyle: "normal", fontWeight: "bold"], + # base01, base00 + rect: [fill: "#586e75", stroke: "#657b83", strokeWidth: 1], + # base00 + children: [fill: "#657b83", stroke: "#657b83", strokeWidth: 1] + ], + yes: [ + # green + text: [fill: "#859900"], + # base00 + path: [stroke: "#657b83", strokeWidth: 1] + ], + no: [ + # red + text: [fill: "#dc322f"], + # base00 + path: [stroke: "#657b83", strokeWidth: 1] + ] + ] + + def playful_light(), + do: [ + background: "#f0f0f0", + padding: 10, + leaves: [ + text: [fill: "#000", font_size: 12, font_style: "italic", font_weight: "bold"], + rect: [fill: "#e91e63", stroke: "#000", stroke_width: 1, radius: 5] + ], + splits: [ + text: [fill: "#000", font_size: 12, font_style: "normal", font_weight: "bold"], + children: [ + fill: "#000", + font_size: 12, + font_style: "normal", + font_weight: "bold" + ], + rect: [fill: "#8bc34a", stroke: "#000", stroke_width: 1, radius: 10] + ], + yes: [ + path: [stroke: "#4caf50", stroke_width: 2] + ], + no: [ + path: [stroke: "#f44336", stroke_width: 2] + ] + ] + + def playful_dark(), + do: [ + background: "#333", + padding: 10, + leaves: [ + text: [fill: "#fff", font_size: 12, font_style: "italic", font_weight: "bold"], + rect: [fill: "#e91e63", stroke: "#fff", stroke_width: 1, radius: 5] + ], + splits: [ + text: [fill: "#fff", font_size: 12, font_style: "normal", font_weight: "bold"], + rect: [fill: "#8bc34a", stroke: "#fff", stroke_width: 1, radius: 10] + ], + yes: [ + text: [fill: "#4caf50"], + path: [stroke: "#4caf50", stroke_width: 2] + ], + no: [ + text: [fill: "#f44336"], + path: [stroke: "#f44336", stroke_width: 2] + ] + ] + + def dark(), + do: [ + background: "#333", + padding: 10, + leaves: [ + text: [fill: "#fff", fontSize: 12, fontStyle: "normal", fontWeight: "normal"], + rect: [fill: "#666", stroke: "#fff", strokeWidth: 1] + ], + splits: [ + text: [fill: "#fff", fontSize: 12, fontStyle: "normal", fontWeight: "bold"], + rect: [fill: "#444", stroke: "#fff", strokeWidth: 1], + children: [fill: "#fff", stroke: "#fff", strokeWidth: 1] + ] + ] + + def high_contrast(), + do: [ + background: "#000", + padding: 10, + leaves: [ + text: [fill: "#fff", fontSize: 12, fontStyle: "normal", fontWeight: "normal"], + rect: [fill: "#333", stroke: "#fff", strokeWidth: 1] + ], + splits: [ + text: [fill: "#fff", fontSize: 12, fontStyle: "normal", fontWeight: "bold"], + rect: [fill: "#666", stroke: "#fff", strokeWidth: 1] + ] + ] + + def light(), + do: [ + background: "#f0f0f0", + padding: 10, + leaves: [ + text: [fill: "#000", fontSize: 12, fontStyle: "normal", fontWeight: "normal"], + rect: [fill: "#ddd", stroke: "#000", strokeWidth: 1] + ], + splits: [ + text: [fill: "#000", fontSize: 12, fontStyle: "normal", fontWeight: "bold"], + rect: [fill: "#bbb", stroke: "#000", strokeWidth: 1] + ] + ] + + def monokai(), + do: [ + background: "#272822", + leaves: [ + text: [fill: "#f8f8f2", fontSize: 12, fontStyle: "normal", fontWeight: "normal"], + rect: [fill: "#3e3d32", stroke: "#66d9ef", strokeWidth: 1] + ], + splits: [ + text: [fill: "#f8f8f2", fontSize: 12, fontStyle: "normal", fontWeight: "bold"], + rect: [fill: "#66d9ef", stroke: "#f8f8f2", strokeWidth: 1], + children: [fill: "#f8f8f2", stroke: "#f8f8f2", strokeWidth: 1] + ], + yes: [ + text: [fill: "#a6e22e"], + path: [stroke: "#f8f8f2", strokeWidth: 1] + ], + no: [ + text: [fill: "#f92672"], + path: [stroke: "#f8f8f2", strokeWidth: 1] + ] + ] + + def dracula(), + do: [ + background: "#282a36", + leaves: [ + text: [fill: "#f8f8f2", fontSize: 12, fontStyle: "normal", fontWeight: "normal"], + rect: [fill: "#44475a", stroke: "#ff79c6", strokeWidth: 1] + ], + splits: [ + text: [fill: "#f8f8f2", fontSize: 12, fontStyle: "normal", fontWeight: "bold"], + rect: [fill: "#ff79c6", stroke: "#f8f8f2", strokeWidth: 1], + children: [fill: "#f8f8f2", stroke: "#f8f8f2", strokeWidth: 1] + ], + yes: [ + text: [fill: "#50fa7b"], + path: [stroke: "#f8f8f2", strokeWidth: 1] + ], + no: [ + text: [fill: "#ff5555"], + path: [stroke: "#f8f8f2", strokeWidth: 1] + ] + ] + + def nord(), + do: [ + background: "#2e3440", + leaves: [ + text: [fill: "#d8dee9", fontSize: 12, fontStyle: "normal", fontWeight: "normal"], + rect: [fill: "#3b4252", stroke: "#88c0d0", strokeWidth: 1] + ], + splits: [ + text: [fill: "#d8dee9", fontSize: 12, fontStyle: "normal", fontWeight: "bold"], + rect: [fill: "#88c0d0", stroke: "#d8dee9", strokeWidth: 1], + children: [fill: "#d8dee9", stroke: "#d8dee9", strokeWidth: 1] + ], + yes: [ + text: [fill: "#a3be8c"], + path: [stroke: "#d8dee9", strokeWidth: 1] + ], + no: [ + text: [fill: "#bf616a"], + path: [stroke: "#d8dee9", strokeWidth: 1] + ] + ] + + def material(), + do: [ + background: "#263238", + leaves: [ + text: [fill: "#eceff1", fontSize: 12, fontStyle: "normal", fontWeight: "normal"], + rect: [fill: "#37474f", stroke: "#80cbc4", strokeWidth: 1] + ], + splits: [ + text: [fill: "#eceff1", fontSize: 12, fontStyle: "normal", fontWeight: "bold"], + rect: [fill: "#80cbc4", stroke: "#eceff1", strokeWidth: 1], + children: [fill: "#eceff1", stroke: "#eceff1", strokeWidth: 1] + ], + yes: [ + text: [fill: "#c5e1a5"], + path: [stroke: "#eceff1", strokeWidth: 1] + ], + no: [ + text: [fill: "#ef9a9a"], + path: [stroke: "#eceff1", strokeWidth: 1] + ] + ] + + def one_dark(), + do: [ + background: "#282c34", + leaves: [ + text: [fill: "#abb2bf", fontSize: 12, fontStyle: "normal", fontWeight: "normal"], + rect: [fill: "#3b4048", stroke: "#98c379", strokeWidth: 1] + ], + splits: [ + text: [fill: "#abb2bf", fontSize: 12, fontStyle: "normal", fontWeight: "bold"], + rect: [fill: "#98c379", stroke: "#abb2bf", strokeWidth: 1], + children: [fill: "#abb2bf", stroke: "#abb2bf", strokeWidth: 1] + ], + yes: [ + text: [fill: "#98c379"], + path: [stroke: "#abb2bf", strokeWidth: 1] + ], + no: [ + text: [fill: "#e06c75"], + path: [stroke: "#abb2bf", strokeWidth: 1] + ] + ] + + def gruvbox(), + do: [ + background: "#282828", + leaves: [ + text: [fill: "#ebdbb2", fontSize: 12, fontStyle: "normal", fontWeight: "normal"], + rect: [fill: "#3c3836", stroke: "#b8bb26", strokeWidth: 1] + ], + splits: [ + text: [fill: "#ebdbb2", fontSize: 12, fontStyle: "normal", fontWeight: "bold"], + rect: [fill: "#b8bb26", stroke: "#ebdbb2", strokeWidth: 1], + children: [fill: "#ebdbb2", stroke: "#ebdbb2", strokeWidth: 1] + ], + yes: [ + text: [fill: "#b8bb26"], + path: [stroke: "#ebdbb2", strokeWidth: 1] + ], + no: [ + text: [fill: "#fb4934"], + path: [stroke: "#ebdbb2", strokeWidth: 1] + ] + ] +end diff --git a/mix.exs b/mix.exs index c2a68b4..0b27e5d 100644 --- a/mix.exs +++ b/mix.exs @@ -1,6 +1,6 @@ defmodule EXGBoost.MixProject do use Mix.Project - @version "0.4.0" + @version "0.5.0" def project do [ From 9610d511290fe10a67aaef1c4cd438143fec8ab6 Mon Sep 17 00:00:00 2001 From: acalejos Date: Fri, 19 Jan 2024 22:15:18 -0500 Subject: [PATCH 09/21] Update style opts --- lib/exgboost/plotting.ex | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/lib/exgboost/plotting.ex b/lib/exgboost/plotting.ex index 87f6f98..e89e032 100644 --- a/lib/exgboost/plotting.ex +++ b/lib/exgboost/plotting.ex @@ -153,7 +153,9 @@ defmodule EXGBoost.Plotting do :high_contrast, :monokai, :material, - :one_dark + :one_dark, + nil, + false ]} ], rankdir: [ @@ -385,11 +387,15 @@ defmodule EXGBoost.Plotting do def deep_merge_kw(a, b) do Keyword.merge(a, b, fn - _key, val_a, val_b when is_list(val_a) and is_list(val_b) -> + key, val_a, val_b when is_list(val_a) and is_list(val_b) -> deep_merge_kw(val_a, val_b) - _key, _val_a, val_b -> - val_b + key, val_a, val_b -> + if Keyword.has_key?(b, key) do + val_b + else + val_a + end end) end @@ -457,11 +463,7 @@ defmodule EXGBoost.Plotting do opts = unless opts[:style] in [nil, false] do - style = - NimbleOptions.validate!( - apply(EXGBoost.Plotting.Styles, opts[:style], []), - @plotting_schema - ) + style = apply(EXGBoost.Plotting.Styles, opts[:style], []) deep_merge_kw(opts, style) else @@ -785,11 +787,7 @@ defmodule EXGBoost.Plotting do opts = unless opts[:style] in [nil, false] do - style = - NimbleOptions.validate!( - apply(EXGBoost.Plotting.Styles, opts[:style], []), - @plotting_schema - ) + style = apply(EXGBoost.Plotting.Styles, opts[:style], []) deep_merge_kw(opts, style) else From f3c5aae972c1d120b0834ff9d395e73313c4b183 Mon Sep 17 00:00:00 2001 From: acalejos Date: Fri, 19 Jan 2024 23:14:16 -0500 Subject: [PATCH 10/21] Merge with main - fix errors --- lib/exgboost/plotting.ex | 2 +- lib/exgboost/training.ex | 29 ++++++++++++++++++++++++----- test/exgboost_test.exs | 24 ++++++++++++++---------- 3 files changed, 39 insertions(+), 16 deletions(-) diff --git a/lib/exgboost/plotting.ex b/lib/exgboost/plotting.ex index e89e032..c25b1f3 100644 --- a/lib/exgboost/plotting.ex +++ b/lib/exgboost/plotting.ex @@ -387,7 +387,7 @@ defmodule EXGBoost.Plotting do def deep_merge_kw(a, b) do Keyword.merge(a, b, fn - key, val_a, val_b when is_list(val_a) and is_list(val_b) -> + _key, val_a, val_b when is_list(val_a) and is_list(val_b) -> deep_merge_kw(val_a, val_b) key, val_a, val_b -> diff --git a/lib/exgboost/training.ex b/lib/exgboost/training.ex index 7bf942f..8c27b44 100644 --- a/lib/exgboost/training.ex +++ b/lib/exgboost/training.ex @@ -13,13 +13,15 @@ defmodule EXGBoost.Training do learning_rates: nil, num_boost_rounds: 10, obj: nil, - verbose_eval: true + verbose_eval: true, + disable_default_eval_metric: false ] {opts, booster_params} = Keyword.split(opts, Keyword.keys(valid_opts)) [ callbacks: callbacks, + disable_default_eval_metric: disable_default_eval_metric, early_stopping_rounds: early_stopping_rounds, evals: evals, learning_rates: learning_rates, @@ -55,7 +57,14 @@ defmodule EXGBoost.Training do ) defaults = - default_callbacks(bst, learning_rates, verbose_eval, evals_dmats, early_stopping_rounds) + default_callbacks( + bst, + learning_rates, + verbose_eval, + evals_dmats, + early_stopping_rounds, + disable_default_eval_metric + ) callbacks = Enum.map(callbacks ++ defaults, fn %Callback{fun: fun} = callback -> @@ -119,7 +128,14 @@ defmodule EXGBoost.Training do %{state | iteration: iter} end - defp default_callbacks(bst, learning_rates, verbose_eval, evals_dmats, early_stopping_rounds) do + defp default_callbacks( + bst, + learning_rates, + verbose_eval, + evals_dmats, + early_stopping_rounds, + disable_default_eval_metric + ) do default_callbacks = [] default_callbacks = @@ -154,12 +170,15 @@ defmodule EXGBoost.Training do if early_stopping_rounds && evals_dmats != [] do [{_dmat, target_eval} | _tail] = Enum.reverse(evals_dmats) + # This is still somewhat hacky and relies on a modification made to + # XGBoost in the Makefile to dump the config to JSON. + # %{"learner" => %{"metrics" => metrics, "default_metric" => default_metric}} = - EXGBoost.dump_config(bst) |> Jason.decode!() + EXGBoost.dump_config(bst) |> Jason.decode!() |> dbg() metric_name = cond do - Enum.empty?(metrics) && opts[:disable_default_eval_metric] -> + Enum.empty?(metrics) && disable_default_eval_metric -> raise ArgumentError, "`:early_stopping_rounds` requires at least one evaluation set. This means you have likely set `disable_default_eval_metric: true` and have not set any explicit evalutation metrics. Please supply at least one metric in the `:eval_metric` option or set `disable_default_eval_metric: false` (default option)" diff --git a/test/exgboost_test.exs b/test/exgboost_test.exs index e239ff6..cf96554 100644 --- a/test/exgboost_test.exs +++ b/test/exgboost_test.exs @@ -152,16 +152,20 @@ defmodule EXGBoostTest do refute is_nil(booster.best_iteration) refute is_nil(booster.best_score) - {booster, _} = - ExUnit.CaptureIO.with_io(fn -> - EXGBoost.train(x, y, - disable_default_eval_metric: true, - num_boost_rounds: 10, - early_stopping_rounds: 1, - evals: [{x, y, "validation"}], - tree_method: :hist - ) - end) + # If no eval metric is provided, the default metric is used. If the default + # metric is disabled, an error is raised. + assert_raise ArgumentError, + fn -> + ExUnit.CaptureIO.with_io(fn -> + EXGBoost.train(x, y, + disable_default_eval_metric: true, + num_boost_rounds: 10, + early_stopping_rounds: 1, + evals: [{x, y, "validation"}], + tree_method: :hist + ) + end) + end refute is_nil(booster.best_iteration) refute is_nil(booster.best_score) From 4c5672006172475266b665ea7c2a9e9f12018b1e Mon Sep 17 00:00:00 2001 From: acalejos Date: Fri, 19 Jan 2024 23:35:58 -0500 Subject: [PATCH 11/21] set feature names in dmatrices --- lib/exgboost/training.ex | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/lib/exgboost/training.ex b/lib/exgboost/training.ex index 8c27b44..3c58715 100644 --- a/lib/exgboost/training.ex +++ b/lib/exgboost/training.ex @@ -6,6 +6,8 @@ defmodule EXGBoost.Training do @spec train(DMatrix.t(), Keyword.t()) :: Booster.t() def train(%DMatrix{} = dmat, opts \\ []) do + dmat_opts = Keyword.take(opts, EXGBoost.Internal.dmatrix_feature_opts()) + valid_opts = [ callbacks: [], early_stopping_rounds: nil, @@ -47,7 +49,7 @@ defmodule EXGBoost.Training do evals_dmats = Enum.map(evals, fn {x, y, name} -> - {DMatrix.from_tensor(x, y, format: :dense), name} + {DMatrix.from_tensor(x, y, Keyword.put_new(dmat_opts, :format, :dense)), name} end) bst = @@ -174,7 +176,7 @@ defmodule EXGBoost.Training do # XGBoost in the Makefile to dump the config to JSON. # %{"learner" => %{"metrics" => metrics, "default_metric" => default_metric}} = - EXGBoost.dump_config(bst) |> Jason.decode!() |> dbg() + EXGBoost.dump_config(bst) |> Jason.decode!() metric_name = cond do From c254c85b357224b949bdeac706d533a623d374c5 Mon Sep 17 00:00:00 2001 From: acalejos Date: Sat, 20 Jan 2024 23:58:26 -0500 Subject: [PATCH 12/21] Fix plotting defaut centering --- lib/exgboost/plotting.ex | 8 +++++--- lib/exgboost/training/callback.ex | 3 +++ 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/lib/exgboost/plotting.ex b/lib/exgboost/plotting.ex index c25b1f3..e38679a 100644 --- a/lib/exgboost/plotting.ex +++ b/lib/exgboost/plotting.ex @@ -199,7 +199,7 @@ defmodule EXGBoost.Plotting do bottom: [type: :pos_integer] ] ]}, - default: 5 + default: 30 ], leaves: [ doc: "Specifies characteristics of leaf nodes", @@ -631,6 +631,8 @@ defmodule EXGBoost.Plotting do "expr" => "#{if opts[:rankdir] == :rl, do: -1, else: 1}*(datum.x+(width/2))", "type" => "formula" }, + %{"type" => "extent", "field" => "x", "signal" => "x_extent"}, + %{"type" => "extent", "field" => "y", "signal" => "y_extent"}, %{"as" => "xscaled", "expr" => "scale('xscale',datum.x)", "type" => "formula"}, %{"as" => "parent", "expr" => "datum.parentid", "type" => "formula"} ] @@ -1168,7 +1170,7 @@ defmodule EXGBoost.Plotting do "[anchor[0] + (xdom[0] - anchor[0]) * zoom, anchor[0] + (xdom[1] - anchor[0]) * zoom]" } ], - "update" => "slice(xext)" + "update" => "[x_extent[0] - nodeWidth/ 2, x_extent[1] + nodeWidth / 2]" }, %{ "name" => "ydom", @@ -1184,7 +1186,7 @@ defmodule EXGBoost.Plotting do "[anchor[1] + (ydom[0] - anchor[1]) * zoom, anchor[1] + (ydom[1] - anchor[1]) * zoom]" } ], - "update" => "slice(yext)" + "update" => "[y_extent[0] - nodeHeight, y_extent[1] + nodeHeight/3]" }, %{"name" => "scaledNodeWidth", "update" => "(nodeWidth/ span(xdom))*width"}, %{"name" => "scaledNodeHeight", "update" => "abs(nodeHeight/ span(ydom))*height"}, diff --git a/lib/exgboost/training/callback.ex b/lib/exgboost/training/callback.ex index dadd0c0..8f873db 100644 --- a/lib/exgboost/training/callback.ex +++ b/lib/exgboost/training/callback.ex @@ -166,6 +166,9 @@ defmodule EXGBoost.Training.Callback do true -> early_stop = Map.update!(early_stop, :since_last_improvement, &(&1 + 1)) + # TODO: Should this actually update the best iteration and score? + # This iteration is not the best, but it is the last one, so do we want + # another way to track last iteration? bst = struct(bst, best_iteration: state.iteration, best_score: score) %{state | booster: bst, meta_vars: %{meta_vars | early_stop: early_stop}, status: :halt} end From e8e8b5704a2da6398093918a22ff42dca432264b Mon Sep 17 00:00:00 2001 From: acalejos Date: Mon, 22 Jan 2024 00:05:45 -0500 Subject: [PATCH 13/21] Add Style gallery in docs --- lib/exgboost/plotting.ex | 529 ++++++++++++++++++++++++++++---- lib/exgboost/plotting/style.ex | 16 + lib/exgboost/plotting/styles.ex | 303 +----------------- mix.exs | 32 ++ test/data/model.json | Bin 0 -> 31162 bytes 5 files changed, 526 insertions(+), 354 deletions(-) create mode 100644 lib/exgboost/plotting/style.ex create mode 100644 test/data/model.json diff --git a/lib/exgboost/plotting.ex b/lib/exgboost/plotting.ex index e38679a..05a147a 100644 --- a/lib/exgboost/plotting.ex +++ b/lib/exgboost/plotting.ex @@ -1,60 +1,397 @@ defmodule EXGBoost.Plotting do - @moduledoc """ - Functions for plotting EXGBoost `Booster` models using [Vega](https://vega.github.io/vega/) + use EXGBoost.Plotting.Style - Fundamentally, all this module does is convert a `Booster` into a format that can be - ingested by Vega, and apply some default configuations that only account for a subset of the configurations - that can be set by a Vega spec directly. The functions provided in this module are designed to have opinionated - defaults that can be used to quickly visualize a model, but the full power of Vega is available by using the - `to_tabular/1` function to convert the model into a tabular format, and then using the `to_vega/2` function - to convert the tabular format into a Vega specification. + @doc """ + A light theme based on the [Solarized](https://ethanschoonover.com/solarized/) color palette + """ + style :solarized_light do + [ + # base3 + background: "#fdf6e3", + leaves: [ + # base01 + text: [fill: "#586e75", fontSize: 12, fontStyle: "normal", fontWeight: "normal"], + # base2, base1 + rect: [fill: "#eee8d5", stroke: "#93a1a1", strokeWidth: 1] + ], + splits: [ + # base01 + text: [fill: "#586e75", fontSize: 12, fontStyle: "normal", fontWeight: "bold"], + # base1, base00 + rect: [fill: "#93a1a1", stroke: "#657b83", strokeWidth: 1], + # base00 + children: [fill: "#657b83", stroke: "#657b83", strokeWidth: 1] + ], + yes: [ + # green + text: [fill: "#859900"], + # base00 + path: [stroke: "#657b83", strokeWidth: 1] + ], + no: [ + # red + text: [fill: "#dc322f"], + # base00 + path: [stroke: "#657b83", strokeWidth: 1] + ] + ] + end - ## Default Vega Specification + @doc """ + A dark theme based on the [Solarized](https://ethanschoonover.com/solarized/) color palette + """ + style :solarized_dark do + # base03 + [ + background: "#002b36", + leaves: [ + # base0 + text: [fill: "#839496", fontSize: 12, fontStyle: "normal", fontWeight: "normal"], + # base02, base01 + rect: [fill: "#073642", stroke: "#586e75", strokeWidth: 1] + ], + splits: [ + # base0 + text: [fill: "#839496", fontSize: 12, fontStyle: "normal", fontWeight: "bold"], + # base01, base00 + rect: [fill: "#586e75", stroke: "#657b83", strokeWidth: 1], + # base00 + children: [fill: "#657b83", stroke: "#657b83", strokeWidth: 1] + ], + yes: [ + # green + text: [fill: "#859900"], + # base00 + path: [stroke: "#657b83", strokeWidth: 1] + ], + no: [ + # red + text: [fill: "#dc322f"], + # base00 + path: [stroke: "#657b83", strokeWidth: 1] + ] + ] + end - The default Vega specification is designed to be a good starting point for visualizing a model, but it is - possible to customize the specification by passing in a map of Vega properties to the `to_vega/2` function. - Refer to `Custom Vega Specifications` for more details on how to do this. + @doc """ + A light and playful theme + """ + style :playful_light do + [ + background: "#f0f0f0", + padding: 10, + leaves: [ + text: [fill: "#000", font_size: 12, font_style: "italic", font_weight: "bold"], + rect: [fill: "#e91e63", stroke: "#000", stroke_width: 1, radius: 5] + ], + splits: [ + text: [fill: "#000", font_size: 12, font_style: "normal", font_weight: "bold"], + children: [ + fill: "#000", + font_size: 12, + font_style: "normal", + font_weight: "bold" + ], + rect: [fill: "#8bc34a", stroke: "#000", stroke_width: 1, radius: 10] + ], + yes: [ + path: [stroke: "#4caf50", stroke_width: 2] + ], + no: [ + path: [stroke: "#f44336", stroke_width: 2] + ] + ] + end - By default, the Vega specification includes the following entities to use for rendering the model: - * `:width` - The width of the plot in pixels - * `:height` - The height of the plot in pixels - * `:padding` - The padding in pixels to add around the visualization. If a number, specifies padding for all sides. If an object, the value should have the format `[left: value, right: value, top: value, bottom: value]` - * `:leafs` - Specifies characteristics of leaf nodes - * `:inner_nodes` - Specifies characteristics of inner nodes - * `:links` - Specifies characteristics of links between nodes + @doc """ + A dark and playful theme + """ + style :playful_dark do + [ + background: "#333", + padding: 10, + leaves: [ + text: [fill: "#fff", font_size: 12, font_style: "italic", font_weight: "bold"], + rect: [fill: "#e91e63", stroke: "#fff", stroke_width: 1, radius: 5] + ], + splits: [ + text: [fill: "#fff", font_size: 12, font_style: "normal", font_weight: "bold"], + rect: [fill: "#8bc34a", stroke: "#fff", stroke_width: 1, radius: 10] + ], + yes: [ + text: [fill: "#4caf50"], + path: [stroke: "#4caf50", stroke_width: 2] + ], + no: [ + text: [fill: "#f44336"], + path: [stroke: "#f44336", stroke_width: 2] + ] + ] + end - ## Custom Vega Specifications + @doc """ + A dark theme + """ + style :dark do + [ + background: "#333", + padding: 10, + leaves: [ + text: [fill: "#fff", fontSize: 12, fontStyle: "normal", fontWeight: "normal"], + rect: [fill: "#666", stroke: "#fff", strokeWidth: 1] + ], + splits: [ + text: [fill: "#fff", fontSize: 12, fontStyle: "normal", fontWeight: "bold"], + rect: [fill: "#444", stroke: "#fff", strokeWidth: 1], + children: [fill: "#fff", stroke: "#fff", strokeWidth: 1] + ] + ] + end - The default Vega specification is designed to be a good starting point for visualizing a model, but it is - possible to customize the specification by passing in a map of Vega properties to the `to_vega/2` function. - You can find the full list of Vega properties [here](https://vega.github.io/vega/docs/specification/). + @doc """ + A high contrast theme + """ + style :high_contrast do + [ + background: "#000", + padding: 10, + leaves: [ + text: [fill: "#fff", fontSize: 12, fontStyle: "normal", fontWeight: "normal"], + rect: [fill: "#333", stroke: "#fff", strokeWidth: 1] + ], + splits: [ + text: [fill: "#fff", fontSize: 12, fontStyle: "normal", fontWeight: "bold"], + rect: [fill: "#666", stroke: "#fff", strokeWidth: 1] + ] + ] + end - It is suggested that you use the data attributes provided by the default specification as a starting point, since they - provide the necessary data transformation to convert the model into a tree structure that can be visualized by Vega. - If you would like to customize the default specification, you can use `EXGBoost.Plotting.to_vega/1` to get the default - specification, and then modify it as needed. + @doc """ + A light theme + """ + style :light do + [ + background: "#f0f0f0", + padding: 10, + leaves: [ + text: [fill: "#000", fontSize: 12, fontStyle: "normal", fontWeight: "normal"], + rect: [fill: "#ddd", stroke: "#000", strokeWidth: 1] + ], + splits: [ + text: [fill: "#000", fontSize: 12, fontStyle: "normal", fontWeight: "bold"], + rect: [fill: "#bbb", stroke: "#000", strokeWidth: 1] + ] + ] + end - Once you have a custom specification, you can pass it to `VegaLite.from_json/1` to create a new `VegaLite` struct, after which - you can use the functions provided by the `VegaLite` module to render the model. + @doc """ + A theme based on the [Monokai](https://monokai.pro/) color palette + """ + style :monokai do + [ + background: "#272822", + leaves: [ + text: [fill: "#f8f8f2", fontSize: 12, fontStyle: "normal", fontWeight: "normal"], + rect: [fill: "#3e3d32", stroke: "#66d9ef", strokeWidth: 1] + ], + splits: [ + text: [fill: "#f8f8f2", fontSize: 12, fontStyle: "normal", fontWeight: "bold"], + rect: [fill: "#66d9ef", stroke: "#f8f8f2", strokeWidth: 1], + children: [fill: "#f8f8f2", stroke: "#f8f8f2", strokeWidth: 1] + ], + yes: [ + text: [fill: "#a6e22e"], + path: [stroke: "#f8f8f2", strokeWidth: 1] + ], + no: [ + text: [fill: "#f92672"], + path: [stroke: "#f8f8f2", strokeWidth: 1] + ] + ] + end - ## Specification Validation - You can optionally validate your specification against the Vega schema by passing the `validate: true` option to `to_vega/2`. - This will raise an error if the specification is invalid. This is useful if you are creating a custom specification and want - to ensure that it is valid. Note that this will only validate the specification against the Vega schema, and not against the - VegaLite schema. This requires the [`ex_json_schema`] package to be installed. + @doc """ + A theme based on the [Dracula](https://draculatheme.com/) color palette + """ + style :dracula do + [ + background: "#282a36", + leaves: [ + text: [fill: "#f8f8f2", fontSize: 12, fontStyle: "normal", fontWeight: "normal"], + rect: [fill: "#44475a", stroke: "#ff79c6", strokeWidth: 1] + ], + splits: [ + text: [fill: "#f8f8f2", fontSize: 12, fontStyle: "normal", fontWeight: "bold"], + rect: [fill: "#ff79c6", stroke: "#f8f8f2", strokeWidth: 1], + children: [fill: "#f8f8f2", stroke: "#f8f8f2", strokeWidth: 1] + ], + yes: [ + text: [fill: "#50fa7b"], + path: [stroke: "#f8f8f2", strokeWidth: 1] + ], + no: [ + text: [fill: "#ff5555"], + path: [stroke: "#f8f8f2", strokeWidth: 1] + ] + ] + end - ## Livebook Integration + @doc """ + A theme based on the [Nord](https://www.nordtheme.com/) color palette + """ + style :nord do + [ + background: "#2e3440", + leaves: [ + text: [fill: "#d8dee9", fontSize: 12, fontStyle: "normal", fontWeight: "normal"], + rect: [fill: "#3b4252", stroke: "#88c0d0", strokeWidth: 1] + ], + splits: [ + text: [fill: "#d8dee9", fontSize: 12, fontStyle: "normal", fontWeight: "bold"], + rect: [fill: "#88c0d0", stroke: "#d8dee9", strokeWidth: 1], + children: [fill: "#d8dee9", stroke: "#d8dee9", strokeWidth: 1] + ], + yes: [ + text: [fill: "#a3be8c"], + path: [stroke: "#d8dee9", strokeWidth: 1] + ], + no: [ + text: [fill: "#bf616a"], + path: [stroke: "#d8dee9", strokeWidth: 1] + ] + ] + end - This module also provides a `Kino.Render` implementation for `EXGBoost.Booster` which allows - models to be rendered directly in Livebook. This is done by converting the model into a Vega specification - and then using the `Kino.Render` implementation for Elixir's [`VegaLite`](https://hexdocs.pm/vega_lite/VegaLite.html) API - to render the model. + @doc """ + A theme based on the [Material](https://material.io/design/color/the-color-system.html#tools-for-picking-colors) color palette + """ + style :material do + [ + background: "#263238", + leaves: [ + text: [fill: "#eceff1", fontSize: 12, fontStyle: "normal", fontWeight: "normal"], + rect: [fill: "#37474f", stroke: "#80cbc4", strokeWidth: 1] + ], + splits: [ + text: [fill: "#eceff1", fontSize: 12, fontStyle: "normal", fontWeight: "bold"], + rect: [fill: "#80cbc4", stroke: "#eceff1", strokeWidth: 1], + children: [fill: "#eceff1", stroke: "#eceff1", strokeWidth: 1] + ], + yes: [ + text: [fill: "#c5e1a5"], + path: [stroke: "#eceff1", strokeWidth: 1] + ], + no: [ + text: [fill: "#ef9a9a"], + path: [stroke: "#eceff1", strokeWidth: 1] + ] + ] + end + @doc """ + A theme based on the One Dark color palette + """ + style :one_dark do + [ + background: "#282c34", + leaves: [ + text: [fill: "#abb2bf", fontSize: 12, fontStyle: "normal", fontWeight: "normal"], + rect: [fill: "#3b4048", stroke: "#98c379", strokeWidth: 1] + ], + splits: [ + text: [fill: "#abb2bf", fontSize: 12, fontStyle: "normal", fontWeight: "bold"], + rect: [fill: "#98c379", stroke: "#abb2bf", strokeWidth: 1], + children: [fill: "#abb2bf", stroke: "#abb2bf", strokeWidth: 1] + ], + yes: [ + text: [fill: "#98c379"], + path: [stroke: "#abb2bf", strokeWidth: 1] + ], + no: [ + text: [fill: "#e06c75"], + path: [stroke: "#abb2bf", strokeWidth: 1] + ] + ] + end - . The Vega specification is then passed to [VegaLite](https://hexdocs.pm/vega_lite/readme.html) + @doc """ + A theme based on the [Gruvbox](https://github.com/morhetz/gruvbox) color palette + """ + style :gruvbox do + [ + background: "#282828", + leaves: [ + text: [fill: "#ebdbb2", fontSize: 12, fontStyle: "normal", fontWeight: "normal"], + rect: [fill: "#3c3836", stroke: "#b8bb26", strokeWidth: 1] + ], + splits: [ + text: [fill: "#ebdbb2", fontSize: 12, fontStyle: "normal", fontWeight: "bold"], + rect: [fill: "#b8bb26", stroke: "#ebdbb2", strokeWidth: 1], + children: [fill: "#ebdbb2", stroke: "#ebdbb2", strokeWidth: 1] + ], + yes: [ + text: [fill: "#b8bb26"], + path: [stroke: "#ebdbb2", strokeWidth: 1] + ], + no: [ + text: [fill: "#fb4934"], + path: [stroke: "#ebdbb2", strokeWidth: 1] + ] + ] + end + @doc """ + A dark theme based on the [Horizon](https://www.horizon.io/) color palette + """ + style :horizon_dark do + [ + background: "#1C1E26", + leaves: [ + text: [fill: "#E3E6EE", fontSize: 12, fontStyle: "normal", fontWeight: "normal"], + rect: [fill: "#232530", stroke: "#F43E5C", strokeWidth: 1] + ], + splits: [ + text: [fill: "#E3E6EE", fontSize: 12, fontStyle: "normal", fontWeight: "bold"], + rect: [fill: "#F43E5C", stroke: "#E3E6EE", strokeWidth: 1], + children: [fill: "#E3E6EE", stroke: "#E3E6EE", strokeWidth: 1] + ], + yes: [ + text: [fill: "#48B685"], + path: [stroke: "#E3E6EE", strokeWidth: 1] + ], + no: [ + text: [fill: "#F43E5C"], + path: [stroke: "#E3E6EE", strokeWidth: 1] + ] + ] + end + @doc """ + A light theme based on the [Horizon](https://www.horizon.io/) color palette """ + style :horizon_light do + [ + background: "#FDF0ED", + leaves: [ + text: [fill: "#1A2026", fontSize: 12, fontStyle: "normal", fontWeight: "normal"], + rect: [fill: "#F7E3D3", stroke: "#F43E5C", strokeWidth: 1] + ], + splits: [ + text: [fill: "#1A2026", fontSize: 12, fontStyle: "normal", fontWeight: "bold"], + rect: [fill: "#F43E5C", stroke: "#1A2026", strokeWidth: 1], + children: [fill: "#1A2026", stroke: "#1A2026", strokeWidth: 1] + ], + yes: [ + text: [fill: "#48B685"], + path: [stroke: "#1A2026", strokeWidth: 1] + ], + no: [ + text: [fill: "#F43E5C"], + path: [stroke: "#1A2026", strokeWidth: 1] + ] + ] + end + HTTPoison.start() @schema HTTPoison.get!("https://vega.github.io/schema/vega/v5.json").body @@ -136,27 +473,10 @@ defmodule EXGBoost.Plotting do @plotting_params [ style: [ - doc: "The style to use for the visualization.", + doc: + "The style to use for the visualization. Refer to `EXGBoost.Plotting.Styles` for a list of available styles.", default: :solarized_light, - type: - {:in, - [ - :solarized_light, - :solarized_dark, - :playful_light, - :playful_dark, - :dark, - :light, - :nord, - :dracula, - :gruvbox, - :high_contrast, - :monokai, - :material, - :one_dark, - nil, - false - ]} + type: {:in, Keyword.keys(@styles)} ], rankdir: [ doc: "Determines the direction of the graph.", @@ -343,13 +663,89 @@ defmodule EXGBoost.Plotting do ] @plotting_schema NimbleOptions.new!(@plotting_params) - @defaults NimbleOptions.validate!([], @plotting_schema) + @moduledoc """ + Functions for plotting EXGBoost `Booster` models using [Vega](https://vega.github.io/vega/) + + Fundamentally, all this module does is convert a `Booster` into a format that can be + ingested by Vega, and apply some default configuations that only account for a subset of the configurations + that can be set by a Vega spec directly. The functions provided in this module are designed to have opinionated + defaults that can be used to quickly visualize a model, but the full power of Vega is available by using the + `to_tabular/1` function to convert the model into a tabular format, and then using the `to_vega/2` function + to convert the tabular format into a Vega specification. + + ## Default Vega Specification + + The default Vega specification is designed to be a good starting point for visualizing a model, but it is + possible to customize the specification by passing in a map of Vega properties to the `to_vega/2` function. + Refer to `Custom Vega Specifications` for more details on how to do this. + + By default, the Vega specification includes the following entities to use for rendering the model: + * `:width` - The width of the plot in pixels + * `:height` - The height of the plot in pixels + * `:padding` - The padding in pixels to add around the visualization. If a number, specifies padding for all sides. If an object, the value should have the format `[left: value, right: value, top: value, bottom: value]` + * `:leafs` - Specifies characteristics of leaf nodes + * `:inner_nodes` - Specifies characteristics of inner nodes + * `:links` - Specifies characteristics of links between nodes + + ## Custom Vega Specifications + + The default Vega specification is designed to be a good starting point for visualizing a model, but it is + possible to customize the specification by passing in a map of Vega properties to the `to_vega/2` function. + You can find the full list of Vega properties [here](https://vega.github.io/vega/docs/specification/). + + It is suggested that you use the data attributes provided by the default specification as a starting point, since they + provide the necessary data transformation to convert the model into a tree structure that can be visualized by Vega. + If you would like to customize the default specification, you can use `EXGBoost.Plotting.to_vega/1` to get the default + specification, and then modify it as needed. + + Once you have a custom specification, you can pass it to `VegaLite.from_json/1` to create a new `VegaLite` struct, after which + you can use the functions provided by the `VegaLite` module to render the model. + + ## Specification Validation + You can optionally validate your specification against the Vega schema by passing the `validate: true` option to `to_vega/2`. + This will raise an error if the specification is invalid. This is useful if you are creating a custom specification and want + to ensure that it is valid. Note that this will only validate the specification against the Vega schema, and not against the + VegaLite schema. This requires the [`ex_json_schema`] package to be installed. + + ## Livebook Integration + + This module also provides a `Kino.Render` implementation for `EXGBoost.Booster` which allows + models to be rendered directly in Livebook. This is done by converting the model into a Vega specification + and then using the `Kino.Render` implementation for Elixir's [`VegaLite`](https://hexdocs.pm/vega_lite/VegaLite.html) API + to render the model. + + . The Vega specification is then passed to [VegaLite](https://hexdocs.pm/vega_lite/readme.html) + + ## Plotting Parameters + + This module exposes a high-level API for customizing the EXGBoost model visualization, but it is also possible to + customize the Vega specification directly. You can also choose to pass in Vega Mark specifications to customize the + appearance of the nodes and links in the visualization outside of the parameteres specified below. Refer to the + [Vega documentation](https://vega.github.io/vega/docs/marks/) for more details on how to do this. + + #{NimbleOptions.docs(@plotting_params)} + + + ## Styles + + Styles are a keyword-map that adhere to the plotting schema as defined in `EXGBoost.Plotting`. + `EXGBoost.Plotting.Styles` provides a set of predefined styles that can be used to quickly customize the appearance of the visualization. + + + Refer to the `EXGBoost.Plotting.Styles` module for a list of available styles. You can pass a style to the `:style` + option as an atom or string, and it will be applied to the visualization. Styles will override any other options that are passed + for each option where the style defined a value. For example, if you pass `:solarized_light` as the style, and also pass + `:background` as an option, the `:background` option will be ignored since the `:solarized_light` style defines its own value for `:background`. + """ + def get_schema(), do: @schema def get_defaults(), do: @defaults + def get_styles(), do: @styles + defp validate_spec(spec) do case ExJsonSchema.Validator.validate(get_schema(), spec) do :ok -> @@ -385,7 +781,7 @@ defmodule EXGBoost.Plotting do end end - def deep_merge_kw(a, b) do + defp deep_merge_kw(a, b) do Keyword.merge(a, b, fn _key, val_a, val_b when is_list(val_a) and is_list(val_b) -> deep_merge_kw(val_a, val_b) @@ -463,7 +859,7 @@ defmodule EXGBoost.Plotting do opts = unless opts[:style] in [nil, false] do - style = apply(EXGBoost.Plotting.Styles, opts[:style], []) + style = apply(__MODULE__, opts[:style], []) deep_merge_kw(opts, style) else @@ -789,7 +1185,7 @@ defmodule EXGBoost.Plotting do opts = unless opts[:style] in [nil, false] do - style = apply(EXGBoost.Plotting.Styles, opts[:style], []) + style = apply(__MODULE__, opts[:style], []) deep_merge_kw(opts, style) else @@ -1209,7 +1605,6 @@ defmodule EXGBoost.Plotting do }) spec = Map.merge(spec, tlk) |> Map.delete("style") - File.write!("/Users/andres/Documents/exgboost/spec.json", Jason.encode!(spec)) spec = if opts[:validate], do: validate_spec(spec), else: spec Jason.encode!(spec) |> VegaLite.from_json() @@ -1270,6 +1665,10 @@ defmodule EXGBoost.Plotting do defp capitalize(<>) when first in ?a..?z, do: <> defp capitalize(rest), do: rest + + def __after_compile__(env) do + IO.inspect(env) + end end defimpl Kino.Render, for: EXGBoost.Booster do diff --git a/lib/exgboost/plotting/style.ex b/lib/exgboost/plotting/style.ex new file mode 100644 index 0000000..002a62b --- /dev/null +++ b/lib/exgboost/plotting/style.ex @@ -0,0 +1,16 @@ +defmodule EXGBoost.Plotting.Style do + @moduledoc false + defmacro __using__(_opts) do + quote do + import EXGBoost.Plotting.Style + Module.register_attribute(__MODULE__, :styles, accumulate: true) + end + end + + defmacro style(style_name, do: body) do + quote do + def unquote(style_name)(), do: unquote(body) + Module.put_attribute(__MODULE__, :styles, {unquote(style_name), unquote(body)}) + end + end +end diff --git a/lib/exgboost/plotting/styles.ex b/lib/exgboost/plotting/styles.ex index 62d4807..cd6fdf8 100644 --- a/lib/exgboost/plotting/styles.ex +++ b/lib/exgboost/plotting/styles.ex @@ -1,293 +1,18 @@ defmodule EXGBoost.Plotting.Styles do + @bst File.cwd!() |> Path.join("test/data/model.json") |> EXGBoost.read_model() + @moduledoc """ - A style is a keyword-map that adheres to the plotting schema - as defined in `EXGBoost.Plotting`. +
+ #{Enum.map(EXGBoost.Plotting.get_styles(), fn {name, _style} -> """ +
+

#{name}

+
+        
+        #{EXGBoost.Plotting.to_vega(@bst, style: name, height: 200, width: 300).spec |> Jason.encode!()}
+        
+      
+
+ """ end) |> Enum.join("\n\n")} +
""" - - def solarized_light(), - do: [ - # base3 - background: "#fdf6e3", - leaves: [ - # base01 - text: [fill: "#586e75", fontSize: 12, fontStyle: "normal", fontWeight: "normal"], - # base2, base1 - rect: [fill: "#eee8d5", stroke: "#93a1a1", strokeWidth: 1] - ], - splits: [ - # base01 - text: [fill: "#586e75", fontSize: 12, fontStyle: "normal", fontWeight: "bold"], - # base1, base00 - rect: [fill: "#93a1a1", stroke: "#657b83", strokeWidth: 1], - # base00 - children: [fill: "#657b83", stroke: "#657b83", strokeWidth: 1] - ], - yes: [ - # green - text: [fill: "#859900"], - # base00 - path: [stroke: "#657b83", strokeWidth: 1] - ], - no: [ - # red - text: [fill: "#dc322f"], - # base00 - path: [stroke: "#657b83", strokeWidth: 1] - ] - ] - - def solarized_dark(), - do: [ - # base03 - background: "#002b36", - leaves: [ - # base0 - text: [fill: "#839496", fontSize: 12, fontStyle: "normal", fontWeight: "normal"], - # base02, base01 - rect: [fill: "#073642", stroke: "#586e75", strokeWidth: 1] - ], - splits: [ - # base0 - text: [fill: "#839496", fontSize: 12, fontStyle: "normal", fontWeight: "bold"], - # base01, base00 - rect: [fill: "#586e75", stroke: "#657b83", strokeWidth: 1], - # base00 - children: [fill: "#657b83", stroke: "#657b83", strokeWidth: 1] - ], - yes: [ - # green - text: [fill: "#859900"], - # base00 - path: [stroke: "#657b83", strokeWidth: 1] - ], - no: [ - # red - text: [fill: "#dc322f"], - # base00 - path: [stroke: "#657b83", strokeWidth: 1] - ] - ] - - def playful_light(), - do: [ - background: "#f0f0f0", - padding: 10, - leaves: [ - text: [fill: "#000", font_size: 12, font_style: "italic", font_weight: "bold"], - rect: [fill: "#e91e63", stroke: "#000", stroke_width: 1, radius: 5] - ], - splits: [ - text: [fill: "#000", font_size: 12, font_style: "normal", font_weight: "bold"], - children: [ - fill: "#000", - font_size: 12, - font_style: "normal", - font_weight: "bold" - ], - rect: [fill: "#8bc34a", stroke: "#000", stroke_width: 1, radius: 10] - ], - yes: [ - path: [stroke: "#4caf50", stroke_width: 2] - ], - no: [ - path: [stroke: "#f44336", stroke_width: 2] - ] - ] - - def playful_dark(), - do: [ - background: "#333", - padding: 10, - leaves: [ - text: [fill: "#fff", font_size: 12, font_style: "italic", font_weight: "bold"], - rect: [fill: "#e91e63", stroke: "#fff", stroke_width: 1, radius: 5] - ], - splits: [ - text: [fill: "#fff", font_size: 12, font_style: "normal", font_weight: "bold"], - rect: [fill: "#8bc34a", stroke: "#fff", stroke_width: 1, radius: 10] - ], - yes: [ - text: [fill: "#4caf50"], - path: [stroke: "#4caf50", stroke_width: 2] - ], - no: [ - text: [fill: "#f44336"], - path: [stroke: "#f44336", stroke_width: 2] - ] - ] - - def dark(), - do: [ - background: "#333", - padding: 10, - leaves: [ - text: [fill: "#fff", fontSize: 12, fontStyle: "normal", fontWeight: "normal"], - rect: [fill: "#666", stroke: "#fff", strokeWidth: 1] - ], - splits: [ - text: [fill: "#fff", fontSize: 12, fontStyle: "normal", fontWeight: "bold"], - rect: [fill: "#444", stroke: "#fff", strokeWidth: 1], - children: [fill: "#fff", stroke: "#fff", strokeWidth: 1] - ] - ] - - def high_contrast(), - do: [ - background: "#000", - padding: 10, - leaves: [ - text: [fill: "#fff", fontSize: 12, fontStyle: "normal", fontWeight: "normal"], - rect: [fill: "#333", stroke: "#fff", strokeWidth: 1] - ], - splits: [ - text: [fill: "#fff", fontSize: 12, fontStyle: "normal", fontWeight: "bold"], - rect: [fill: "#666", stroke: "#fff", strokeWidth: 1] - ] - ] - - def light(), - do: [ - background: "#f0f0f0", - padding: 10, - leaves: [ - text: [fill: "#000", fontSize: 12, fontStyle: "normal", fontWeight: "normal"], - rect: [fill: "#ddd", stroke: "#000", strokeWidth: 1] - ], - splits: [ - text: [fill: "#000", fontSize: 12, fontStyle: "normal", fontWeight: "bold"], - rect: [fill: "#bbb", stroke: "#000", strokeWidth: 1] - ] - ] - - def monokai(), - do: [ - background: "#272822", - leaves: [ - text: [fill: "#f8f8f2", fontSize: 12, fontStyle: "normal", fontWeight: "normal"], - rect: [fill: "#3e3d32", stroke: "#66d9ef", strokeWidth: 1] - ], - splits: [ - text: [fill: "#f8f8f2", fontSize: 12, fontStyle: "normal", fontWeight: "bold"], - rect: [fill: "#66d9ef", stroke: "#f8f8f2", strokeWidth: 1], - children: [fill: "#f8f8f2", stroke: "#f8f8f2", strokeWidth: 1] - ], - yes: [ - text: [fill: "#a6e22e"], - path: [stroke: "#f8f8f2", strokeWidth: 1] - ], - no: [ - text: [fill: "#f92672"], - path: [stroke: "#f8f8f2", strokeWidth: 1] - ] - ] - - def dracula(), - do: [ - background: "#282a36", - leaves: [ - text: [fill: "#f8f8f2", fontSize: 12, fontStyle: "normal", fontWeight: "normal"], - rect: [fill: "#44475a", stroke: "#ff79c6", strokeWidth: 1] - ], - splits: [ - text: [fill: "#f8f8f2", fontSize: 12, fontStyle: "normal", fontWeight: "bold"], - rect: [fill: "#ff79c6", stroke: "#f8f8f2", strokeWidth: 1], - children: [fill: "#f8f8f2", stroke: "#f8f8f2", strokeWidth: 1] - ], - yes: [ - text: [fill: "#50fa7b"], - path: [stroke: "#f8f8f2", strokeWidth: 1] - ], - no: [ - text: [fill: "#ff5555"], - path: [stroke: "#f8f8f2", strokeWidth: 1] - ] - ] - - def nord(), - do: [ - background: "#2e3440", - leaves: [ - text: [fill: "#d8dee9", fontSize: 12, fontStyle: "normal", fontWeight: "normal"], - rect: [fill: "#3b4252", stroke: "#88c0d0", strokeWidth: 1] - ], - splits: [ - text: [fill: "#d8dee9", fontSize: 12, fontStyle: "normal", fontWeight: "bold"], - rect: [fill: "#88c0d0", stroke: "#d8dee9", strokeWidth: 1], - children: [fill: "#d8dee9", stroke: "#d8dee9", strokeWidth: 1] - ], - yes: [ - text: [fill: "#a3be8c"], - path: [stroke: "#d8dee9", strokeWidth: 1] - ], - no: [ - text: [fill: "#bf616a"], - path: [stroke: "#d8dee9", strokeWidth: 1] - ] - ] - - def material(), - do: [ - background: "#263238", - leaves: [ - text: [fill: "#eceff1", fontSize: 12, fontStyle: "normal", fontWeight: "normal"], - rect: [fill: "#37474f", stroke: "#80cbc4", strokeWidth: 1] - ], - splits: [ - text: [fill: "#eceff1", fontSize: 12, fontStyle: "normal", fontWeight: "bold"], - rect: [fill: "#80cbc4", stroke: "#eceff1", strokeWidth: 1], - children: [fill: "#eceff1", stroke: "#eceff1", strokeWidth: 1] - ], - yes: [ - text: [fill: "#c5e1a5"], - path: [stroke: "#eceff1", strokeWidth: 1] - ], - no: [ - text: [fill: "#ef9a9a"], - path: [stroke: "#eceff1", strokeWidth: 1] - ] - ] - - def one_dark(), - do: [ - background: "#282c34", - leaves: [ - text: [fill: "#abb2bf", fontSize: 12, fontStyle: "normal", fontWeight: "normal"], - rect: [fill: "#3b4048", stroke: "#98c379", strokeWidth: 1] - ], - splits: [ - text: [fill: "#abb2bf", fontSize: 12, fontStyle: "normal", fontWeight: "bold"], - rect: [fill: "#98c379", stroke: "#abb2bf", strokeWidth: 1], - children: [fill: "#abb2bf", stroke: "#abb2bf", strokeWidth: 1] - ], - yes: [ - text: [fill: "#98c379"], - path: [stroke: "#abb2bf", strokeWidth: 1] - ], - no: [ - text: [fill: "#e06c75"], - path: [stroke: "#abb2bf", strokeWidth: 1] - ] - ] - - def gruvbox(), - do: [ - background: "#282828", - leaves: [ - text: [fill: "#ebdbb2", fontSize: 12, fontStyle: "normal", fontWeight: "normal"], - rect: [fill: "#3c3836", stroke: "#b8bb26", strokeWidth: 1] - ], - splits: [ - text: [fill: "#ebdbb2", fontSize: 12, fontStyle: "normal", fontWeight: "bold"], - rect: [fill: "#b8bb26", stroke: "#ebdbb2", strokeWidth: 1], - children: [fill: "#ebdbb2", stroke: "#ebdbb2", strokeWidth: 1] - ], - yes: [ - text: [fill: "#b8bb26"], - path: [stroke: "#ebdbb2", strokeWidth: 1] - ], - no: [ - text: [fill: "#fb4934"], - path: [stroke: "#ebdbb2", strokeWidth: 1] - ] - ] end diff --git a/mix.exs b/mix.exs index 568dc1f..eb55c5b 100644 --- a/mix.exs +++ b/mix.exs @@ -126,6 +126,38 @@ defmodule EXGBoost.MixProject do } }); + + + + + + + """ end diff --git a/test/data/model.json b/test/data/model.json new file mode 100644 index 0000000000000000000000000000000000000000..6f3a37853464eee5baca47559960a4f361dd048f GIT binary patch literal 31162 zcmdsA3wTpS+D@BZNYj+wxNW&Bf++WkP?DK5-Q^~Ag?~XrkdU@%8@5eKl5$azo4U9l za`6YNsN6*4CJI)x?I|lFEW!f1ii&sv1dFl)%OWgA|C!T%lausJ3R?3M=Xp*}PBJs+ z`)1zv``$C>G<`5rS5sns?{elk@*afa*aExF?X=FI0;|hubvP&63LH7s zyrN>OBZt?G=ngewoYsl1ERV)Fu}*Kk+iufp6z{e;3|)4}w_Jhv^!VRw41SuU4{*5X5xoR{Tw z+wImuSB||PVAztJ#f8d81qJjwR?TNZ59r#!DzWc;CywmZou3){_quHkXF&h)MQ&HN z-Q%%(rxp3Sj;)<+IDxUoKD#&HRHyb0{dipfkvf7btNaA-1_PW zrdMzG`V2I7@7k-oQhM|R@1pZ;g@wKn>m~e z;Mok})@+y4qYRYShbA4=j47}cX62|(#Ye=X1Pq5W&q|r%^Hj^+*h1S>YnH=DCVkI7 zec=pA?8|JM*UISZ$#)g_02kS}cimbp8aM3uF4ZLQN(rnhN-3|6Qr}+02f#NzhMF35 zIDB7nOXAp<)T_xVTKSqL7E&%{=Q|2=tW)fcynK}bd`uEp9Zyk#!)s*`(svFX&r%Cr zPM4Q@pgHW3oo)q(ZknX6LgBcDQ^ta2O_fXvp{2HD)fQd)J;dpyjVV%RR+eHB|C->5G~WSH|81 zRz&c)naF|)jeB*YT2hWZt2ocu#Dk~h6qAYOa`QlUxqlK&3TU6{-k}!*o;GMyhHPKGv)nwne)O3^X zapihxPRru1#MTHFj~0Ub(+QU{yshsv9CZFo0!|9%HD^{Zcqy(aQxuSQHRo=+4?nSV z);%;F`>fVGA{{zMB)2X4E+!YaU3COooXtyxT2`^w?s5|(Ztc*k>+R(vtv-i zfWw*V@-=Bqn?xuh{4-vm<7n2gYS3Eq!H8EFUy~?2V{gvk(M)g2>a-VqdTxnqJ@O&+ z6G!u(J(2U4HGiVlL37g8G90wCb{-#KhP8T>`rkWNMh4cc4 zR#zEBjfo62D6XbDBbaCyjUF1j^QkO%*yy;aM`F~dk>u&)3o;rVUjRcAr@-0&YD&qYL=+ifolHZp%P09GdQ6{QKjMuQO?RD@?>qMt z+B*0P{Op8ol^tGu32*As#r&gL|C)Z?0D1JQi_zg%7a{$+bnc@M%j)svlYC<_!U$7N z_uv6zMg$y)nZcSW%%FS5l4!cc(#-%cRqvrYA#;o*Z+tv$ts|GNjFb9AUBYR}-=Gnm zne4N`0=;Q5%Q%CcSJi-lnr+NPwnf!$%-m?`C(`TsH5+@+&l`#e5ZQwZ&tDn>qRW@=s;<0c+-9_AOuTG8`x@#u z;#=-aSW%q;-k;==;E0&ytWgBr^nh@-hh{=DM`!dqeNKAhp|$LcP@t6-X`?A@)TCvn z`gf0w)xZinZ|#yogTGj7}Lz zltzDPIVmG<{d#M*MN*hE0>X#QLm8+5B--s*}s#NFm4oHwPhyByyY}MTS)R%e>mMQ%~yu@Wk}oku^2k2kG-@VW4m$* zr*;G0LqEaUt7a{`YzqdK;Oteimc6f2z@Ua;@1qZ9L~dJ7o?I!D^PNXYEq-W@<<>uu zIh6;{L43n4OqL(g zA-QM&PAJzi19!>Esf^t497bm^nGfztGPk_rTe;iia?7ga7ty3gx^W-FiV<2XQH&X} z;EtFXtZ5zHNk99U3zO0z-DL(o(z9teszMTKvVRBrEUp~MEQaI2f#cCnz-%ozh!}3n=sjQt8MC($xw7aAS$^9Jg7%e?zqaZt zC3^0}YhvFforlTzKc%;k;yEwDjD_TSvjG<-jb(>7ov%z>o`P>}{hoZs;I#w3UpUFK zyZ;g_HJxg?G7D8W`>e6dX}=y#8cMj6zljjg2;8YB1W=4I*hPc7qwh8EPe6A<5o;v# zhc-pkJ0{>+g)?a{Nr$?fXxHjpfJblgaR?Hbj@y275WN%A9VKp(*k|g37WQ2T$Iz=5=S(x`xuGK6Gq|&j!I^D-x9aI} z9NVD7AZd}gUuUJ3wFOsUPd|%?_t0<9^k|YKx^hbas3>Hrj9FS_3ONvfQcAC0fdCY9 zT``y%!wT@8bHCoCPQovwAOwb09C_0!j(xS(Z?st`tqsLE@1^7$SW=H~=le{0l zcBnmGQ@WAFz1k0-c26bwhcofo)ThbF)X-i%NA_sA%jYuVsu7jZi-+KZT~7J7H(ciN z2gX^prp(1Rz4@u-#JhJ@OgQj}#zYv)>6aqTRCl{hGXxQA#6gY@fb z=igPYA}iy@lc)Z5k}MmLi`UTaJeth5_|Otw+SypXe4bfOI~Ip}WXR~sv?BTN@PRmS{2p`K=O34i zuV^Q?Je!Wvr(H3he~)L{@Y2kpq-TQp6*EIgdQ1vZG%!RccRcC-=dKwK@BIjEt&^S` zF$c#@+=4cIw;efp&Sakj7BEhWS%FE9i9)t96WQjckfEa>%2>%MwQeelMzjyj#UfS* z`U|RvYaV8$u;OdRJ9IkbVb&n95;GT)9<@tx5L7U8woU;91~1RZ?J5>8iNy$63d~lhE)vNv(eJ4odfm%1ONM_-$n>XfH|ut1W=4A z^k?{Ig&zOhHDi5mGy2vv0qc^s0CLd_m2V;Amu=w~_F&w#gn%1e_twIS>s|q@>JI@B z<6w|}eccz-xmeC#MNrfpn6dEF zy**bgMz`NQPj2)yuThM3w6O4p%#rRo#20d>&17BxeqdjqCROdaMYWK1q< z1=*3+1?NOrQBU5xzYQ7G>H-B*?3FITV1pTGY_L<7(NBobRV#!3UyWfXCJCh>6 ztt&=XT6QUadGHnUg8aU6Vq&Cu+wC}bo7l@Wc7v@$(Z8zr6X^jCmHYm8^7L^KB*m@FKr`F<9AACkz=f<(49nm2J^d^zYXoo=S& zzpGHY#LIHC!zZ{CVWkk%qL1l_q}I~&U<7e6l>`>sT$HLKBw2P)-~=}l3Gj=d|= z&vdUyyYGq|oxDkonDU}!$22$hF{~J&73CN+QotQCGnAxPyGXsR}Uc8rF;{tl)33S@Hu|vE-)(D=~8+=}~_a2f<->%!*&~0|N!j z)`EkG;YRYtQ|+bAUAB|?jkb~79^Nc{(!oO#ayR2|DyyZg!{3u;T}{BxIPSr3_mlW} zLlFTQaLw3c9ylUZK74f)+PP!5oO&$PT;8#RWpr8=`u==#%cj#~D~2ri0>zjR>M(J6 zMGHrnWk=N|?rd0v0`cW@DmWw#2Fm3#1)CBGRP69uJ*MCD$E{c$=vEjMOwx>Pm(kwM zmr!Zx3ffvJXma|)s89NrwTC)ndG+t`p6#Jmt>6%N`5aUc)4tfB5Aw^IAVC7IMwOdN zD);?vF44aD<5p_-g9Zze-xJBT0900&cj$Bgz?uM37l7iR0R}@T4kPgsEkeES++eG4`m*S%F{s()MRM}(Tlp}CRd`@X0n3>}XaWb~pun13k9K~1 z$wdY;R-kh&5{F7V2TYfkOw3PC6tG-C67?{OsG%pbZ$i@5!ekQt(QL)VV0--iJ}*ys;I%Z@SE#3oGht z|Exqaa7Ijg){to|mInM*ia|80g`rRv8>HU1&BRYNdDF7;(oEWuQczX(F?i4RV2l>g z1Ai1oQwKu_L3A&3hh@Gxx>w7cpgPx}CNzMB5n2l@VnQ>6ARY6ROh5C}wcwZ_Mc18N zQKB}Gl5Pfa(+eh&P#Hs_cK5~?<-O9tkDE$wE!N>}bK>#gE}?x~DxyC)_e*aa>&w1+ z<8##b_^opMyfOP8(yR>lmY{<#kWvN)_~~}^8`5JXSsk6tS&ew zVuGIh@!Ofwp1mEU595zO<&Y$-hLo&NKXa+#;^~=|!~L(K&2x{)lfKxBk_{It68`Lx zUw*Tr`LVs<$??QIU~{*&+{duWTO!HY0^AWZqoF11n2mp9;KR$hnsLew6Us>CH%`#_ zf3SqY3mB)xtT2(R3|2hDW$xP#u#!HY4?SqWO3Yj>$?BKg_^{D8FmeKBYr#Rpa3h&H z>SJlN^Br>MfuE%BtHwz$madYvkLrg%AHG;}SNFqPjy0EVE-1#&wRw!6Hxw7G4Y=HE zH1BSfFE5+06_x*|ryRe*Xqh+g7IbuHf#u>u52H4%H(GAJC$jul*=Eazk}&2V#*X$G^@JeYv)|F z-PEl zPQYv}4J2Z?k=*xWuGD$;XQbOn7l}V`Ns1rynN)smy3~K%DQWSOP4LT)ufj_*C*Us% zo&3C^hyV?k^o-{3r!1A7Rm0F*UmZoJ^dICsTTYi5Ca*yWzaL#bYjCV(%~Q|HjRy`Y zzcTv_ck(xJ>Dvn2sV4+bm?`w9drj&GawA9xc(vV7obuYcXjbVq%ilZvU@7mzAM>TX z`xP+mT0%ge?q%k(jnSF^Uir1^UfM>J`hji+4*|+E|Lo_}cF>4NH|g&s&(sI2;30rH z1axQbg#|{gK|;XTZJ$WhDKAK6_oPTIiocd#xpxzO{kB%p0%InAcFcIZY;l1!<@}v^ k|HJ%$m)%±YZm|Nn|t`~Lv>w_*Oz5O2i)0yzEu0MuZx`Tzg` literal 0 HcmV?d00001 From ead27a36fcd3754360a8fb6c29a8ec4b8fd7858d Mon Sep 17 00:00:00 2001 From: acalejos Date: Mon, 22 Jan 2024 23:52:30 -0500 Subject: [PATCH 14/21] WIP plotting livebook --- .gitignore | 3 +- README.md | 2 +- lib/exgboost.ex | 8 +- lib/exgboost/plotting.ex | 223 ++++++++++++++++++++++--------- lib/exgboost/plotting/style.ex | 14 ++ lib/exgboost/plotting/styles.ex | 2 +- mix.exs | 4 +- mix.lock | 3 + notebooks/plotting.livemd | 230 ++++++++++++++++++++++++++++++++ 9 files changed, 417 insertions(+), 72 deletions(-) create mode 100644 notebooks/plotting.livemd diff --git a/.gitignore b/.gitignore index f214385..15fd0cd 100644 --- a/.gitignore +++ b/.gitignore @@ -11,4 +11,5 @@ erl_crash.dump .elixir_ls/ .tool-versions .vscode/ -checksum.exs \ No newline at end of file +checksum.exs +.DS_Store \ No newline at end of file diff --git a/README.md b/README.md index 8da4acd..0a8b04c 100644 --- a/README.md +++ b/README.md @@ -22,7 +22,7 @@ billions of examples. ```elixir def deps do [ - {:exgboost, "~> 0.3"} + {:exgboost, "~> 0.5"} ] end ``` diff --git a/lib/exgboost.ex b/lib/exgboost.ex index 7755b61..dfaf6c9 100644 --- a/lib/exgboost.ex +++ b/lib/exgboost.ex @@ -16,7 +16,7 @@ defmodule EXGBoost do ```elixir def deps do [ - {:exgboost, "~> 0.4"} + {:exgboost, "~> 0.5"} ] end ``` @@ -92,7 +92,7 @@ defmodule EXGBoost do preds = EXGBoost.train(X, y) |> EXGBoost.predict(X) ``` - ## Serliaztion + ## Serialization A Booster can be serialized to a file using `EXGBoost.write_*` and loaded from a file using `EXGBoost.read_*`. The file format can be specified using the `:format` option @@ -557,11 +557,13 @@ defmodule EXGBoost do ## Options * `:format` - the format to export the graphic as, must be either of: `:json`, `:html`, `:png`, `:svg`, `:pdf`. By default the format is inferred from the file extension. * `:local_npm_prefix` - a relative path pointing to a local npm project directory where the necessary npm packages are installed. For instance, in Phoenix projects you may want to pass local_npm_prefix: "assets". By default the npm packages are searched for in the current directory and globally. + * `:path` - the path to save the graphic to. If not provided, the graphic is returned as a VegaLite spec. + * `:opts` - additional options to pass to `EXGBoost.Plotting.plot/2`. See `EXGBoost.Plotting` for more information. """ def plot_tree(booster, opts \\ []) do {path, opts} = Keyword.pop(opts, :path) {save_opts, opts} = Keyword.split(opts, [:format, :local_npm_prefix]) - vega = Plotting.to_vega(booster, opts) + vega = Plotting.plot(booster, opts) if path != nil do VegaLite.Export.save!(vega, path, save_opts) diff --git a/lib/exgboost/plotting.ex b/lib/exgboost/plotting.ex index 05a147a..6caec2f 100644 --- a/lib/exgboost/plotting.ex +++ b/lib/exgboost/plotting.ex @@ -475,35 +475,35 @@ defmodule EXGBoost.Plotting do style: [ doc: "The style to use for the visualization. Refer to `EXGBoost.Plotting.Styles` for a list of available styles.", - default: :solarized_light, - type: {:in, Keyword.keys(@styles)} + default: Application.compile_env(:exgboost, [:plotting, :style], :solarized_light), + type: {:or, [{:in, Keyword.keys(@styles)}, {:in, [nil, false]}, :keyword_list]} ], rankdir: [ doc: "Determines the direction of the graph.", type: {:in, [:tb, :lr, :bt, :rl]}, - default: :tb + default: Application.compile_env(:exgboost, [:plotting, :rankdir], :tb) ], autosize: [ doc: "Determines if the visualization should automatically resize when the window size changes", type: {:in, ["fit", "pad", "fit-x", "fit-y", "none"]}, - default: "fit" + default: Application.compile_env(:exgboost, [:plotting, :autosize], "fit") ], background: [ doc: "The background color of the visualization. Accepts a valid CSS color string. For example: `#f304d3`, `#ccc`, `rgb(253, 12, 134)`, `steelblue.`", type: :string, - default: "#f5f5f5" + default: Application.compile_env(:exgboost, [:plotting, :background], "#f5f5f5") ], height: [ doc: "Height of the plot in pixels", type: :pos_integer, - default: 400 + default: Application.compile_env(:exgboost, [:plotting, :height], 400) ], width: [ doc: "Width of the plot in pixels", type: :pos_integer, - default: 600 + default: Application.compile_env(:exgboost, [:plotting, :width], 600) ], padding: [ doc: @@ -519,7 +519,7 @@ defmodule EXGBoost.Plotting do bottom: [type: :pos_integer] ] ]}, - default: 30 + default: Application.compile_env(:exgboost, [:plotting, :padding], 30) ], leaves: [ doc: "Specifies characteristics of leaf nodes", @@ -529,18 +529,42 @@ defmodule EXGBoost.Plotting do @mark_opts ++ [ doc: @mark_text_doc, - default: @default_leaf_text + default: + deep_merge_kw( + @default_leaf_text, + Application.compile_env( + :exgboost, + [:plotting, :leaves, :text], + [] + ) + ) ], rect: @mark_opts ++ [ doc: @mark_rect_doc, - default: @default_leaf_rect + default: + deep_merge_kw( + @default_leaf_rect, + Application.compile_env( + :exgboost, + [:plotting, :leaves, :rect], + [] + ) + ) ] ], default: [ - text: @default_leaf_text, - rect: @default_leaf_rect + text: + deep_merge_kw( + @default_leaf_text, + Application.compile_env(:exgboost, [:plotting, :leaves, :text], []) + ), + rect: + deep_merge_kw( + @default_leaf_rect, + Application.compile_env(:exgboost, [:plotting, :leaves, :rect], []) + ) ] ], splits: [ @@ -551,36 +575,76 @@ defmodule EXGBoost.Plotting do @mark_opts ++ [ doc: @mark_text_doc, - default: @default_split_text + default: + deep_merge_kw( + @default_split_text, + Application.compile_env( + :exgboost, + [:plotting, :splits, :text], + [] + ) + ) ], rect: @mark_opts ++ [ doc: @mark_rect_doc, - default: @default_split_rect + default: + deep_merge_kw( + @default_split_rect, + Application.compile_env( + :exgboost, + [:plotting, :splits, :rect], + [] + ) + ) ], children: @mark_opts ++ [ doc: @mark_text_doc, - default: @default_split_children + default: + deep_merge_kw( + @default_split_children, + Application.compile_env( + :exgboost, + [:plotting, :splits, :children], + [] + ) + ) ] ], default: [ - text: @default_split_text, - rect: @default_split_rect, - children: @default_split_children + text: + deep_merge_kw( + @default_split_text, + Application.compile_env(:exgboost, [:plotting, :splits, :text], []) + ), + rect: + deep_merge_kw( + @default_split_rect, + Application.compile_env(:exgboost, [:plotting, :splits, :rect], []) + ), + children: + deep_merge_kw( + @default_split_children, + Application.compile_env( + :exgboost, + [:plotting, :splits, :children], + [] + ) + ) ] ], node_width: [ doc: "The width of each node in pixels", type: :pos_integer, - default: 100 + default: Application.compile_env(:exgboost, [:plotting, :node_width], 100) ], node_height: [ doc: "The height of each node in pixels", type: :pos_integer, - default: 45 + default: Application.compile_env(:exgboost, [:plotting, :node_heigh], 45) ], space_between: [ doc: "The space between the rectangular marks in pixels.", @@ -589,15 +653,18 @@ defmodule EXGBoost.Plotting do nodes: [ doc: "Space between marks within the same depth of the tree.", type: :pos_integer, - default: 10 + default: Application.compile_env(:exgboost, [:plotting, :space_between, :nodes], 10) ], levels: [ doc: "Space between each rank / depth of the tree.", type: :pos_integer, - default: 100 + default: Application.compile_env(:exgboost, [:plotting, :space_between, :levels], 100) ] ], - default: [nodes: 10, levels: 100] + default: [ + nodes: Application.compile_env(:exgboost, [:plotting, :space_between, :nodes], 10), + levels: Application.compile_env(:exgboost, [:plotting, :space_between, :levels], 100) + ] ], yes: [ doc: "Specifies characteristics of links between nodes where the split condition is true", @@ -607,18 +674,34 @@ defmodule EXGBoost.Plotting do @mark_opts ++ [ doc: @mark_path_doc, - default: @default_yes_path + default: + deep_merge_kw( + @default_yes_path, + Application.compile_env(:exgboost, [:plotting, :yes, :path], []) + ) ], text: @mark_opts ++ [ doc: @mark_text_doc, - default: @default_yes_text + default: + deep_merge_kw( + @default_yes_text, + Application.compile_env(:exgboost, [:plotting, :yes, :text], []) + ) ] ], default: [ - path: @default_yes_path, - text: @default_yes_text + path: + deep_merge_kw( + @default_yes_path, + Application.compile_env(:exgboost, [:plotting, :yes, :path], []) + ), + text: + deep_merge_kw( + @default_yes_text, + Application.compile_env(:exgboost, [:plotting, :yes, :text], []) + ) ] ], no: [ @@ -629,36 +712,52 @@ defmodule EXGBoost.Plotting do @mark_opts ++ [ doc: @mark_path_doc, - default: @default_no_path + default: + deep_merge_kw( + @default_no_path, + Application.compile_env(:exgboost, [:plotting, :no, :path], []) + ) ], text: @mark_opts ++ [ doc: @mark_text_doc, - default: @default_no_text + default: + deep_merge_kw( + @default_no_text, + Application.compile_env(:exgboost, [:plotting, :no, :path], []) + ) ] ], default: [ - path: @default_no_path, - text: @default_no_text + path: + deep_merge_kw( + @default_no_path, + Application.compile_env(:exgboost, [:plotting, :no, :path], []) + ), + text: + deep_merge_kw( + @default_no_text, + Application.compile_env(:exgboost, [:plotting, :no, :path], []) + ) ] ], validate: [ doc: "Whether to validate the Vega specification against the Vega schema", type: :boolean, - default: true + default: Application.compile_env(:exgboost, [:plotting, :validate], true) ], index: [ doc: "The zero-indexed index of the tree to plot. If `nil`, plots all trees using a dropdown selector to switch between trees", type: {:or, [nil, :non_neg_integer]}, - default: 0 + default: Application.compile_env(:exgboost, [:plotting, :index], 0) ], depth: [ doc: "The depth of the tree to plot. If `nil`, plots all levels (cick on a node to expand/collapse)", type: {:or, [nil, :non_neg_integer]}, - default: nil + default: Application.compile_env(:exgboost, [:plotting, :depth], nil) ] ] @@ -672,13 +771,13 @@ defmodule EXGBoost.Plotting do ingested by Vega, and apply some default configuations that only account for a subset of the configurations that can be set by a Vega spec directly. The functions provided in this module are designed to have opinionated defaults that can be used to quickly visualize a model, but the full power of Vega is available by using the - `to_tabular/1` function to convert the model into a tabular format, and then using the `to_vega/2` function + `to_tabular/1` function to convert the model into a tabular format, and then using the `plot/2` function to convert the tabular format into a Vega specification. ## Default Vega Specification The default Vega specification is designed to be a good starting point for visualizing a model, but it is - possible to customize the specification by passing in a map of Vega properties to the `to_vega/2` function. + possible to customize the specification by passing in a map of Vega properties to the `plot/2` function. Refer to `Custom Vega Specifications` for more details on how to do this. By default, the Vega specification includes the following entities to use for rendering the model: @@ -692,19 +791,19 @@ defmodule EXGBoost.Plotting do ## Custom Vega Specifications The default Vega specification is designed to be a good starting point for visualizing a model, but it is - possible to customize the specification by passing in a map of Vega properties to the `to_vega/2` function. + possible to customize the specification by passing in a map of Vega properties to the `plot/2` function. You can find the full list of Vega properties [here](https://vega.github.io/vega/docs/specification/). It is suggested that you use the data attributes provided by the default specification as a starting point, since they provide the necessary data transformation to convert the model into a tree structure that can be visualized by Vega. - If you would like to customize the default specification, you can use `EXGBoost.Plotting.to_vega/1` to get the default + If you would like to customize the default specification, you can use `EXGBoost.Plotting.plot/1` to get the default specification, and then modify it as needed. Once you have a custom specification, you can pass it to `VegaLite.from_json/1` to create a new `VegaLite` struct, after which you can use the functions provided by the `VegaLite` module to render the model. ## Specification Validation - You can optionally validate your specification against the Vega schema by passing the `validate: true` option to `to_vega/2`. + You can optionally validate your specification against the Vega schema by passing the `validate: true` option to `plot/2`. This will raise an error if the specification is invalid. This is useful if you are creating a custom specification and want to ensure that it is valid. Note that this will only validate the specification against the Vega schema, and not against the VegaLite schema. This requires the [`ex_json_schema`] package to be installed. @@ -781,20 +880,6 @@ defmodule EXGBoost.Plotting do end end - defp deep_merge_kw(a, b) do - Keyword.merge(a, b, fn - _key, val_a, val_b when is_list(val_a) and is_list(val_b) -> - deep_merge_kw(val_a, val_b) - - key, val_a, val_b -> - if Keyword.has_key?(b, key) do - val_b - else - val_a - end - end) - end - defp new_node(params = %{}) do Map.merge( %{ @@ -858,12 +943,16 @@ defmodule EXGBoost.Plotting do opts = NimbleOptions.validate!(opts, @plotting_schema) opts = - unless opts[:style] in [nil, false] do - style = apply(__MODULE__, opts[:style], []) + cond do + opts[:style] in [nil, false] -> + opts + + Keyword.keyword?(opts[:style]) -> + deep_merge_kw(opts, opts[:style]) - deep_merge_kw(opts, style) - else - opts + true -> + style = apply(__MODULE__, opts[:style], []) + deep_merge_kw(opts, style) end %{ @@ -1179,17 +1268,21 @@ defmodule EXGBoost.Plotting do } end - def to_vega(booster, opts \\ []) do + def plot(booster, opts \\ []) do spec = get_data_spec(booster, opts) opts = NimbleOptions.validate!(opts, @plotting_schema) opts = - unless opts[:style] in [nil, false] do - style = apply(__MODULE__, opts[:style], []) + cond do + opts[:style] in [nil, false] -> + opts - deep_merge_kw(opts, style) - else - opts + Keyword.keyword?(opts[:style]) -> + deep_merge_kw(opts, opts[:style]) + + true -> + style = apply(__MODULE__, opts[:style], []) + deep_merge_kw(opts, style) end defaults = get_defaults() @@ -1673,6 +1766,6 @@ end defimpl Kino.Render, for: EXGBoost.Booster do def to_livebook(booster) do - EXGBoost.Plotting.to_vega(booster) |> Kino.Render.to_livebook() + EXGBoost.Plotting.plot(booster) |> Kino.Render.to_livebook() end end diff --git a/lib/exgboost/plotting/style.ex b/lib/exgboost/plotting/style.ex index 002a62b..8019945 100644 --- a/lib/exgboost/plotting/style.ex +++ b/lib/exgboost/plotting/style.ex @@ -7,6 +7,20 @@ defmodule EXGBoost.Plotting.Style do end end + def deep_merge_kw(a, b) do + Keyword.merge(a, b, fn + _key, val_a, val_b when is_list(val_a) and is_list(val_b) -> + deep_merge_kw(val_a, val_b) + + key, val_a, val_b -> + if Keyword.has_key?(b, key) do + val_b + else + val_a + end + end) + end + defmacro style(style_name, do: body) do quote do def unquote(style_name)(), do: unquote(body) diff --git a/lib/exgboost/plotting/styles.ex b/lib/exgboost/plotting/styles.ex index cd6fdf8..3e9ede3 100644 --- a/lib/exgboost/plotting/styles.ex +++ b/lib/exgboost/plotting/styles.ex @@ -8,7 +8,7 @@ defmodule EXGBoost.Plotting.Styles do

#{name}

         
-        #{EXGBoost.Plotting.to_vega(@bst, style: name, height: 200, width: 300).spec |> Jason.encode!()}
+        #{EXGBoost.plot_tree(@bst, style: name, height: 200, width: 300).spec |> Jason.encode!()}
         
       
diff --git a/mix.exs b/mix.exs index eb55c5b..0cf41de 100644 --- a/mix.exs +++ b/mix.exs @@ -52,7 +52,9 @@ defmodule EXGBoost.MixProject do {:ex_json_schema, "~> 0.10.2"}, {:httpoison, "~> 2.0", runtime: false}, {:vega_lite, "~> 0.1"}, - {:kino, "~> 0.11"} + {:kino, "~> 0.11"}, + {:scidata, "~> 0.1", only: :dev}, + {:kino_vega_lite, "~> 0.1.9", only: :dev} ] end diff --git a/mix.lock b/mix.lock index ad391b0..6c17710 100644 --- a/mix.lock +++ b/mix.lock @@ -19,6 +19,7 @@ "jason": {:hex, :jason, "1.4.0", "e855647bc964a44e2f67df589ccf49105ae039d4179db7f6271dfd3843dc27e6", [:mix], [{:decimal, "~> 1.0 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: true]}], "hexpm", "79a3791085b2a0f743ca04cec0f7be26443738779d09302e01318f97bdb82121"}, "json_ptr": {:hex, :json_ptr, "1.2.0", "07a757d1b0a86b7fd73f5a5b26d4d41c5bc7d5be4a3e3511d7458293ce70b8bb", [:mix], [{:jason, "~> 1.4.0", [hex: :jason, repo: "hexpm", optional: false]}], "hexpm", "e58704ac304cbf3832c0ac161e76479e7b05f75427991ddd57e19b307ae4aa05"}, "kino": {:hex, :kino, "0.11.3", "c48a3f8bcad5ec103884aee119dc01594c0ead7c755061853147357c01885f2c", [:mix], [{:fss, "~> 0.1.0", [hex: :fss, repo: "hexpm", optional: false]}, {:nx, "~> 0.1", [hex: :nx, repo: "hexpm", optional: true]}, {:table, "~> 0.1.2", [hex: :table, repo: "hexpm", optional: false]}], "hexpm", "027e1f634ce175e9092d9bfc632c911ca50ad735eecc77770bfaf4b4a8b2932f"}, + "kino_vega_lite": {:hex, :kino_vega_lite, "0.1.11", "d3c2a00b3685b95f91833920d06cc9b1fd7fb293a2663d89affe9aaec16a5b77", [:mix], [{:kino, "~> 0.7", [hex: :kino, repo: "hexpm", optional: false]}, {:table, "~> 0.1.0", [hex: :table, repo: "hexpm", optional: false]}, {:vega_lite, "~> 0.1.8", [hex: :vega_lite, repo: "hexpm", optional: false]}], "hexpm", "5ccd9148ce7cfcc95a137e12596cd8b95b371e9ea107e745bc262c39c5d8d48e"}, "makeup": {:hex, :makeup, "1.1.0", "6b67c8bc2882a6b6a445859952a602afc1a41c2e08379ca057c0f525366fc3ca", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "0a45ed501f4a8897f580eabf99a2e5234ea3e75a4373c8a52824f6e873be57a6"}, "makeup_elixir": {:hex, :makeup_elixir, "0.16.1", "cc9e3ca312f1cfeccc572b37a09980287e243648108384b97ff2b76e505c3555", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "e127a341ad1b209bd80f7bd1620a15693a9908ed780c3b763bccf7d200c767c6"}, "makeup_erlang": {:hex, :makeup_erlang, "0.1.2", "ad87296a092a46e03b7e9b0be7631ddcf64c790fa68a9ef5323b6cbb36affc72", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "f3f5a1ca93ce6e092d92b6d9c049bcda58a3b617a8d888f8e7231c85630e8108"}, @@ -27,12 +28,14 @@ "mime": {:hex, :mime, "2.0.5", "dc34c8efd439abe6ae0343edbb8556f4d63f178594894720607772a041b04b02", [:mix], [], "hexpm", "da0d64a365c45bc9935cc5c8a7fc5e49a0e0f9932a761c55d6c52b142780a05c"}, "mimerl": {:hex, :mimerl, "1.2.0", "67e2d3f571088d5cfd3e550c383094b47159f3eee8ffa08e64106cdf5e981be3", [:rebar3], [], "hexpm", "f278585650aa581986264638ebf698f8bb19df297f66ad91b18910dfc6e19323"}, "mint": {:hex, :mint, "1.5.1", "8db5239e56738552d85af398798c80648db0e90f343c8469f6c6d8898944fb6f", [:mix], [{:castore, "~> 0.1.0 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}, {:hpax, "~> 0.1.1", [hex: :hpax, repo: "hexpm", optional: false]}], "hexpm", "4a63e1e76a7c3956abd2c72f370a0d0aecddc3976dea5c27eccbecfa5e7d5b1e"}, + "nimble_csv": {:hex, :nimble_csv, "1.2.0", "4e26385d260c61eba9d4412c71cea34421f296d5353f914afe3f2e71cce97722", [:mix], [], "hexpm", "d0628117fcc2148178b034044c55359b26966c6eaa8e2ce15777be3bbc91b12a"}, "nimble_options": {:hex, :nimble_options, "1.0.2", "92098a74df0072ff37d0c12ace58574d26880e522c22801437151a159392270e", [:mix], [], "hexpm", "fd12a8db2021036ce12a309f26f564ec367373265b53e25403f0ee697380f1b8"}, "nimble_parsec": {:hex, :nimble_parsec, "1.3.1", "2c54013ecf170e249e9291ed0a62e5832f70a476c61da16f6aac6dca0189f2af", [:mix], [], "hexpm", "2682e3c0b2eb58d90c6375fc0cc30bc7be06f365bf72608804fb9cffa5e1b167"}, "nimble_pool": {:hex, :nimble_pool, "1.0.0", "5eb82705d138f4dd4423f69ceb19ac667b3b492ae570c9f5c900bb3d2f50a847", [:mix], [], "hexpm", "80be3b882d2d351882256087078e1b1952a28bf98d0a287be87e4a24a710b67a"}, "nx": {:hex, :nx, "0.5.3", "6ad5534f9b82429dafa12329952708c2fdd6ab01b306e86333fdea72383147ee", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "d1072fc4423809ed09beb729e73c200ce177ddecac425d9eb6fba643669623ec"}, "parse_trans": {:hex, :parse_trans, "3.4.1", "6e6aa8167cb44cc8f39441d05193be6e6f4e7c2946cb2759f015f8c56b76e5ff", [:rebar3], [], "hexpm", "620a406ce75dada827b82e453c19cf06776be266f5a67cff34e1ef2cbb60e49a"}, "req": {:hex, :req, "0.4.5", "2071bbedd280f107b9e33e1ddff2beb3991ec1ae06caa2cca2ab756393d8aca5", [:mix], [{:brotli, "~> 0.3.1", [hex: :brotli, repo: "hexpm", optional: true]}, {:ezstd, "~> 1.0", [hex: :ezstd, repo: "hexpm", optional: true]}, {:finch, "~> 0.9", [hex: :finch, repo: "hexpm", optional: false]}, {:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: false]}, {:mime, "~> 1.6 or ~> 2.0", [hex: :mime, repo: "hexpm", optional: false]}, {:nimble_csv, "~> 1.0", [hex: :nimble_csv, repo: "hexpm", optional: true]}, {:plug, "~> 1.0", [hex: :plug, repo: "hexpm", optional: true]}], "hexpm", "dd23e9c7303ddeb2dee09ff11ad8102cca019e38394456f265fb7b9655c64dd8"}, + "scidata": {:hex, :scidata, "0.1.8", "80b1a470efb4f4a2fc2a7d26711217948d8bc2e17f636d41d846035b44f0ce8a", [:mix], [{:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: false]}, {:nimble_csv, "~> 1.1", [hex: :nimble_csv, repo: "hexpm", optional: false]}, {:stb_image, "~> 0.4", [hex: :stb_image, repo: "hexpm", optional: true]}], "hexpm", "8fb125a2a55f52fea78f708f7bbd7bb3b8c233bf3943b411bc659252bcb12bab"}, "ssl_verify_fun": {:hex, :ssl_verify_fun, "1.1.7", "354c321cf377240c7b8716899e182ce4890c5938111a1296add3ec74cf1715df", [:make, :mix, :rebar3], [], "hexpm", "fe4c190e8f37401d30167c8c405eda19469f34577987c76dde613e838bbc67f8"}, "table": {:hex, :table, "0.1.2", "87ad1125f5b70c5dea0307aa633194083eb5182ec537efc94e96af08937e14a8", [:mix], [], "hexpm", "7e99bc7efef806315c7e65640724bf165c3061cdc5d854060f74468367065029"}, "telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"}, diff --git a/notebooks/plotting.livemd b/notebooks/plotting.livemd new file mode 100644 index 0000000..9bb4b2e --- /dev/null +++ b/notebooks/plotting.livemd @@ -0,0 +1,230 @@ +# Plotting in EXGBoost + +```elixir +# Mix.install([ +# {:exgboost, path: Path.join(__DIR__, ".."), env: :dev}, +# ], +# lockfile: :exgboost) + +Mix.install([ + {:exgboost, path: "/Users/andres/Documents/exgboost"}, + {:nx, "~> 0.5"}, + {:scidata, "~> 0.1"}, + {:scholar, "~> 0.1"}, + {:kino_vega_lite, "~> 0.1.9"} +]) + +# This assumed you launch this livebook from its location in the exgboost/notebooks folder +``` + +## Introduction + +Much of the utility from decision trees come from their intuitiveness and ability to inform dcisions outside of the confines of a black-box model. A decision tree can be easily translated to a series of actions that can be taken on behalf of the stakeholder to achieve the desired outcome. This makes them especially useful in business decisions, where people might still want to have the final say but be as informed as possible. Additionally, tabular data is still quite popular in the business domain, which conforms to the required input for decision trees. + +Decision trees can be used for both regression and classification tasks, but classification tends to be what is most associated with decision trees. + + + +This notebook will go over some of the details of the `EXGBoost.Plotting` module, including using preconfiged styles, custom styling, as well as customizing the entire vidualization. + +## Plotting APIs + +There are 2 main APIs exposed to control plotting in `EXGBoost`: + +* Top-level API (`EXGBoost.plot_tree/2`) + + * Using predefined styles + * Defining custom styles + * Mix of the first 2 + +* `EXBoost.Plotting` module API + + * Use the Vega `data` spec defined in `EXGBoost.get_data_spec/2` + * Define your own Vega spec using the data from either `EXGBoost.Plotting.to_tabular/1` or some other means + + We will walk through each of these in detail. + +Regardless of which API you choose to use, it is helpful to understand how the plotting module works (althought the higher-level API you choose to work with the less important it becomes). + +## Implementation Details + +The plotting functionality provided in `EXGBoost` is powered by the [Vega](https://vega.github.io/vega/) JavaScript library and the Elixir [`VegaLite`](https://hexdocs.pm/vega_lite/VegaLite.html) library which provides the piping to interop with the JavaScript libraries. **We do not actually much use the Elixir API provided by the Elixir VegaLite library. It is mainly used for the purposes of rendering.** + +Vega is a plotting library built on top of the very powerful [D3](https://d3js.org/) JavaScript library. Vega visualizations are defined according to the respective JSON Schema specification. Vega-Lite offers a [reduced schema](https://vega.github.io/schema/vega-lite/v5.json) compared to the [full Vega spec](https://vega.github.io/schema/vega/v5.json). `EXGBoost.Plotting` leverages several transforms which are not available in the reduced Vega-Lite schema, which is the reason for targeting the lower-level API. + +For these reasons, unfortunately we could not just implement plotting for `EXGBoost` as a composable Vega-Lite pipeline. This makes working synamically with the spec a bit more unwieldly, but much care was taken to still make the high-level plotting API extensible, and if needed you can go straight to defining your own JSON spec. + +## Setup Data + +We will still be using the Iris dataset for this notebook, but if you want more details about the process of training and evaluating a model please check out the `Iris Classification with Gradient Boosting` notebook. + +So let's proceed by setting up the Iris dataset. + +```elixir +{x, y} = Scidata.Iris.download() +data = Enum.zip(x, y) |> Enum.shuffle() +{train, test} = Enum.split(data, ceil(length(data) * 0.8)) +{x_train, y_train} = Enum.unzip(train) +{x_test, y_test} = Enum.unzip(test) + +x_train = Nx.tensor(x_train) +y_train = Nx.tensor(y_train) + +x_test = Nx.tensor(x_test) +y_test = Nx.tensor(y_test) +``` + +## Train Your Booster + +Now go ahead and train your booster. We will use `early_stopping_rounds: 1` because we're not interested in the accuracy of the booster for this demonstration (*Note that we need to set `evals` to use early stopping*). + +You will notice that `EXGBoost` also provides an implementation for `Kino.Render` so that `EXGBoost.Booster`s are rendered as a plot by default. + +```elixir +booster = + EXGBoost.train( + x_train, + y_train, + num_class: 3, + objective: :multi_softprob, + num_boost_rounds: 10, + evals: [{x_train, y_train, "training"}], + verbose_eval: false, + early_stopping_rounds: 1 + ) +``` + +You'll notice that the plot doesn't display any labels to the features in the splits, and instead only shows features labelled as "f2" etc. If you provide feature labels during training, your plot will show the splits using the feature labels. + +```elixir +booster = + EXGBoost.train(x_train, y_train, + num_class: 3, + objective: :multi_softprob, + num_boost_rounds: 10, + evals: [{x_train, y_train, "training"}], + verbose_eval: false, + feature_name: ["sepal length", "sepal width", "petal length", "petal width"], + early_stopping_rounds: 1 + ) +``` + +## Top-Level API + +`EXGBoost.plot_tree/2` is the quickest way to customize the output of the plot. + + + +### Styles + +There are a set of provided pre-configured settings for the top-level API that you may optionally use. You can refer to the `EXGBoost.Plottings.Styles` docs to see a gallery of each style in action. You can specify a style with the `:style` option in `EXGBoost.plot_tree/2`. + +You can still specify custom settings along with using a style. Most styles only specify a subset of the total possible settings, so you are free to specify any other allowed keys and they will be merged with the style. However, any options set by the style **do** take precedence over options. + +For example, let's look at the `:solarized_dark` style: + +```elixir +EXGBoost.Plotting.solarized_dark() |> Keyword.take([:background, :height]) |> IO.inspect() +EXGBoost.plot_tree(booster, style: :solarized_dark) +``` + +You can see that it defines a background color of `#002b36` but does not restrict what the height must be. + +```elixir +EXGBoost.plot_tree(booster, style: :solarized_dark, background: "white", height: 200) +``` + +We specified both `:background` and `:height` here, but only `:height` was changed because it was not specified in the style. + +If you want to leverage a style but have the flexibility to change something it defines, you can always get the style specification as a `Keyword` which can be passed to `EXGBoost.plot_tree/2` manually, making any needed changes yourself, like so: + +```elixir +custom_style = EXGBoost.Plotting.solarized_dark() |> Keyword.put(:background, "white") +EXGBoost.plot_tree(booster, style: custom_style) +``` + +The benefits of using a style is you still get to leverage all of the defaults provided by the API. Look at the difference between changing the background as we just did by specifying the style versus using the style as the new `opts` argument: + +```elixir +custom_style = EXGBoost.Plotting.solarized_dark() |> Keyword.put(:background, "white") +EXGBoost.plot_tree(booster, [style: nil] ++ custom_style) +``` + +As you can see, it maintained the pieces that were **EXPLICITLY** set by the style, but lost some of the defaults that improve the plot appearance. + +Obviously, if you wish to specify all parameters yourself, this shouldn't be an issue. + + + +### Configuration + +You can also set defaults for the top-level API using an `Application` configuration for `EXGBoost` under the `:plotting` key. Since the defaults are collected from your configuration file at compile-time, anything you set during runtime, even if you set it to the Application environment, will not be registered as defaults. + +For example, if you just want to change the default pre-configured style you can do: + + + +```elixir +Mix.install([ + {:exgboost, path: Path.join(__DIR__, ".."), env: :dev}, +], + config: + [ + exgboost: [ + plotting: [ + style: :solarized_dark, + ]] + ], + lockfile: :exgboost) +``` + +You can also make one-off changes to any of the settings with this method. In effect, this turns into a default custom style. **Just make sure to set `style: nil` to ensure that the `style` option doesn't supercede any of your settings.** Here's an example of that: + + + +```elixir + default_style = + [ + style: nil, + background: "#3f3f3f", + leaves: [ + # Foreground + text: [fill: "#dcdccc", font_size: 12, font_style: "normal", font_weight: "normal"], + # Comment + rect: [fill: "#7f9f7f", stroke: "#7f9f7f"] + ], + splits: [ + # Foreground + text: [fill: "#dcdccc", font_size: 12, font_style: "normal", font_weight: "bold"], + # Comment + rect: [fill: "#7f9f7f", stroke: "#7f9f7f"], + # Selection + children: [fill: "#2b2b2b", stroke: "#2b2b2b"] + ], + yes: [ + # Green + text: [fill: "#7f9f7f"], + # Selection + path: [stroke: "#2b2b2b"] + ], + no: [ + # Red + text: [fill: "#cc9393"], + # Selection + path: [stroke: "#2b2b2b"] + ] + ] + +Mix.install([ + {:exgboost, path: Path.join(__DIR__, ".."), env: :dev}, +], +config: + [ + exgboost: [ + plotting: default_style, + ] + ] +) +``` + +**NOTE: When you specify a parameter in the configuration, it is merged with the defaults which is different from runtime behavior.** From d80378c952a001b45b47e3eeecfbe2a17e14d197 Mon Sep 17 00:00:00 2001 From: acalejos Date: Tue, 23 Jan 2024 23:31:17 -0500 Subject: [PATCH 15/21] WIP plotting notebook --- lib/exgboost/plotting.ex | 65 +++++----- lib/exgboost/plotting/style.ex | 20 +++- notebooks/plotting.livemd | 209 +++++++++++++++++++++++++++++++-- 3 files changed, 247 insertions(+), 47 deletions(-) diff --git a/lib/exgboost/plotting.ex b/lib/exgboost/plotting.ex index 6caec2f..a86c4ac 100644 --- a/lib/exgboost/plotting.ex +++ b/lib/exgboost/plotting.ex @@ -413,13 +413,11 @@ defmodule EXGBoost.Plotting do align: :center, baseline: :middle, font_size: 13, - font: "Calibri", - fill: :black + font: "Calibri" ] @default_leaf_rect [ corner_radius: 2, - fill: :teal, opacity: 1 ] @@ -427,47 +425,38 @@ defmodule EXGBoost.Plotting do align: :center, baseline: :middle, font_size: 13, - font: "Calibri", - fill: :black + font: "Calibri" ] @default_split_rect [ corner_radius: 2, - fill: :teal, opacity: 1 ] @default_split_children [ align: :right, baseline: :middle, - fill: :black, font: "Calibri", font_size: 13 ] - @default_yes_path [ - stroke: :red - ] + @default_yes_path [] @default_yes_text [ align: :center, baseline: :middle, font_size: 13, font: "Calibri", - fill: :black, text: "yes" ] - @default_no_path [ - stroke: :black - ] + @default_no_path [] @default_no_text [ align: :center, baseline: :middle, font_size: 13, font: "Calibri", - fill: :black, text: "no" ] @@ -475,7 +464,7 @@ defmodule EXGBoost.Plotting do style: [ doc: "The style to use for the visualization. Refer to `EXGBoost.Plotting.Styles` for a list of available styles.", - default: Application.compile_env(:exgboost, [:plotting, :style], :solarized_light), + default: Application.compile_env(:exgboost, [:plotting, :style], :dracula), type: {:or, [{:in, Keyword.keys(@styles)}, {:in, [nil, false]}, :keyword_list]} ], rankdir: [ @@ -948,12 +937,13 @@ defmodule EXGBoost.Plotting do opts Keyword.keyword?(opts[:style]) -> - deep_merge_kw(opts, opts[:style]) + deep_merge_kw(opts[:style], opts, @defaults) true -> style = apply(__MODULE__, opts[:style], []) - deep_merge_kw(opts, style) + deep_merge_kw(style, opts, @defaults) end + |> then(&deep_merge_kw(@defaults, &1)) %{ "$schema" => "https://vega.github.io/schema/vega/v5.json", @@ -1278,14 +1268,13 @@ defmodule EXGBoost.Plotting do opts Keyword.keyword?(opts[:style]) -> - deep_merge_kw(opts, opts[:style]) + deep_merge_kw(opts[:style], opts, @defaults) true -> style = apply(__MODULE__, opts[:style], []) - deep_merge_kw(opts, style) + deep_merge_kw(style, opts, @defaults) end - - defaults = get_defaults() + |> then(&deep_merge_kw(@defaults, &1)) # Try to account for non-default node height / width to adjust spacing # between nodes and levels as a quality of life improvement @@ -1293,17 +1282,17 @@ defmodule EXGBoost.Plotting do opts = cond do opts[:rankdir] in [:lr, :rl] and - opts[:space_between][:levels] == defaults[:space_between][:levels] -> + opts[:space_between][:levels] == @defaults[:space_between][:levels] -> put_in( opts[:space_between][:levels], - defaults[:space_between][:levels] + (opts[:node_width] - defaults[:node_width]) + @defaults[:space_between][:levels] + (opts[:node_width] - @defaults[:node_width]) ) opts[:rankdir] in [:tb, :bt] and - opts[:space_between][:levels] == defaults[:space_between][:levels] -> + opts[:space_between][:levels] == @defaults[:space_between][:levels] -> put_in( opts[:space_between][:levels], - defaults[:space_between][:levels] + (opts[:node_height] - defaults[:node_height]) + @defaults[:space_between][:levels] + (opts[:node_height] - @defaults[:node_height]) ) true -> @@ -1313,17 +1302,17 @@ defmodule EXGBoost.Plotting do opts = cond do opts[:rankdir] in [:lr, :rl] and - opts[:space_between][:nodes] == defaults[:space_between][:nodes] -> + opts[:space_between][:nodes] == @defaults[:space_between][:nodes] -> put_in( opts[:space_between][:nodes], - defaults[:space_between][:nodes] + (opts[:node_height] - defaults[:node_height]) + @defaults[:space_between][:nodes] + (opts[:node_height] - @defaults[:node_height]) ) opts[:rankdir] in [:tb, :bt] and - opts[:space_between][:nodes] == defaults[:space_between][:nodes] -> + opts[:space_between][:nodes] == @defaults[:space_between][:nodes] -> put_in( opts[:space_between][:nodes], - defaults[:space_between][:nodes] + (opts[:node_width] - defaults[:node_width]) + @defaults[:space_between][:nodes] + (opts[:node_width] - @defaults[:node_width]) ) true -> @@ -1381,15 +1370,19 @@ defmodule EXGBoost.Plotting do %{ "encode" => %{ "update" => - Map.merge( + deep_merge_maps( %{ "x" => %{ "signal" => - "(scale('xscale', datum.source.x+(nodeWidth/3)) + scale('xscale', datum.target.x)) / 2" + "(scale('xscale', datum.source.x#{cond do + opts[:rankdir] in [:tb, :bt, :lr] -> ~c"-nodeWidth/4" + opts[:rankdir] in [:rl] -> ~c"+nodeWidth/4" + true -> ~c"" + end}) + scale('xscale', datum.target.x)) / 2" }, "y" => %{ "signal" => - "(scale('yscale', datum.source.y) + scale('yscale', datum.target.y)) / 2 - (scaledNodeHeight/2)" + "(scale('yscale', datum.source.y#{if opts[:rankdir] in [:lr, :rl], do: ~c"-nodeWidth/3", else: ~c""}) + scale('yscale', datum.target.y)) / 2 - (scaledNodeHeight/2)" } }, format_mark(opts[:yes][:text]) @@ -1405,7 +1398,11 @@ defmodule EXGBoost.Plotting do %{ "x" => %{ "signal" => - "(scale('xscale', datum.source.x+(nodeWidth/3)) + scale('xscale', datum.target.x)) / 2" + "(scale('xscale', datum.source.x#{cond do + opts[:rankdir] in [:tb, :bt, :lr] -> ~c"-nodeWidth/4" + opts[:rankdir] in [:rl] -> ~c"+nodeWidth/4" + true -> ~c"" + end}) + scale('xscale', datum.target.x)) / 2" }, "y" => %{ "signal" => diff --git a/lib/exgboost/plotting/style.ex b/lib/exgboost/plotting/style.ex index 8019945..d181099 100644 --- a/lib/exgboost/plotting/style.ex +++ b/lib/exgboost/plotting/style.ex @@ -7,13 +7,31 @@ defmodule EXGBoost.Plotting.Style do end end - def deep_merge_kw(a, b) do + def deep_merge_kw(a, b, ignore_set \\ []) do Keyword.merge(a, b, fn _key, val_a, val_b when is_list(val_a) and is_list(val_b) -> deep_merge_kw(val_a, val_b) key, val_a, val_b -> if Keyword.has_key?(b, key) do + if Keyword.has_key?(ignore_set, key) and Keyword.get(ignore_set, key) == val_b do + val_a + else + val_b + end + else + val_a + end + end) + end + + def deep_merge_maps(b, a) do + Map.merge(a, b, fn + _key, val_a, val_b when is_map(val_a) and is_map(val_b) -> + deep_merge_maps(val_a, val_b) + + key, val_a, val_b -> + if Map.has_key?(b, key) do val_b else val_a diff --git a/notebooks/plotting.livemd b/notebooks/plotting.livemd index 9bb4b2e..22a18aa 100644 --- a/notebooks/plotting.livemd +++ b/notebooks/plotting.livemd @@ -113,13 +113,199 @@ booster = `EXGBoost.plot_tree/2` is the quickest way to customize the output of the plot. +This API uses [Vega `Mark`s](https://vega.github.io/vega/docs/marks/) to describe the plot. Each of the following `Mark` options accepts any of the valid keys from their respective `Mark` type as described in the Vega documentation. + +**Please note that these are passed as a `Keyword`, and as such the keys must be atoms rather than strings as the Vega docs show. Valid options for this API are `camel_cased` atoms as opposed to the `pascalCased` strings the Vega docs describe, so if you wish to pass `"fontSize"` as the Vega docs show, you would instead pass it as `font_size:` in this API.** + +The plot is composed of the following parts: + +* Top-level keys: Options controlling parts of the plot outside of direct control of a `Mark`, such as `:padding`, `:autosize`, etc. Accepts any Vega top-level [top-level key](https://vega.github.io/vega/docs/specification/) in addition to several specific to this API (scuh as `:style` and `:depth`). +* `:leaves`: `Mark` specifying the leaf nodes of the tree + * `:text`: [Text Mark](https://vega.github.io/vega/docs/marks/text/) + * `:rect`: [Rect Mark](https://vega.github.io/vega/docs/marks/rect/) +* `:splits` `Mark` specifying the split (or inner / decision) nodes of the tree + * `:text`: [Text Mark](https://vega.github.io/vega/docs/marks/text/) + * `:rect`: [Rect Mark](https://vega.github.io/vega/docs/marks/rect/) + * `:children`: [Text Mark](https://vega.github.io/vega/docs/marks/text/) for the child count +* `:yes` + * `:path`: [Path Mark](https://vega.github.io/vega/docs/marks/path/) + * `:text`: [Text Mark](https://vega.github.io/vega/docs/marks/text/) +* `:no` + * `:path`: [Path Mark](https://vega.github.io/vega/docs/marks/path/) + * `:text`: [Text Mark](https://vega.github.io/vega/docs/marks/text/) + +`EXGBoost.plot_tree/2` defaults to outputting a `VegaLite` struct. If you pass the `:path` option it will save to a file instead. + +If you want to add any marks to the underlying plot you will have to use the lower-level `EXGBoost.Plotting` API, as the top-level API is only capable of customizing these marks. + +### Top-Level Keys + + + +`EXGBoost` supports changing the direction of the plots through the `:rankdir` option. Avaiable directions are `[:tb, :bt, :lr, :rl]`, with top-to-bottom (`:tb`) being the default. + +```elixir +EXGBoost.plot_tree(booster, rankdir: :bt) +``` + +By default, plotting only shows one (the first) tree, but seeing as a `Booster` is really an ensemble of trees you can choose which tree to plot through the `:index` option, or set to `nil` to have a dropdown box to select the tree. + +```elixir +EXGBoost.plot_tree(booster, rankdir: :lr, index: 4) +``` + +You'll also notice that the plot is interactive, with support for scrolling, zooming, and collapsing sections of the tree. If you click on a split node you will toggle the visibility of its descendents, and the rest of the tree will fill the canvas. + +You can also use the `:depth` option to programatically set the max depth to display in the tree: + +```elixir +EXGBoost.plot_tree(booster, rankdir: :lr, index: 4, depth: 3) +``` + +One way to affect the canvas size is by controlling the padding. + +You can add padding to all side by specifying an integer for the `:padding` option + +```elixir +EXGBoost.plot_tree(booster, rankdir: :rl, index: 4, depth: 3, padding: 50) +``` + +Or specify padding for each side: + +```elixir +EXGBoost.plot_tree(booster, + rankdir: :lr, + index: 4, + depth: 3, + padding: [top: 5, bottom: 25, left: 50, right: 10] +) +``` + +You can also specify the canvas size using the `:width` and `:height` options: + +```elixir +EXGBoost.plot_tree(booster, + rankdir: :lr, + index: 4, + depth: 3, + width: 500, + height: 500 +) +``` + +But do note that changing the padding of a canvas does change the size, even if you specify the size using `:height` and `:width` + +```elixir +EXGBoost.plot_tree(booster, + rankdir: :lr, + index: 4, + depth: 3, + width: 500, + height: 500, + padding: 10 +) +``` + +You can change the dimensions of all nodes through the `:node_height` and `:node_width` options: + +```elixir +EXGBoost.plot_tree(booster, rankdir: :lr, index: 4, depth: 3, node_width: 60, node_height: 60) +``` + +Or change the space between nodes using the `:space_between` option. + +**Note that the size of the accompanying nodes and/or text will change to accomodate the new `:space_between` option while trying to maintain the canvas size.** + +```elixir +EXGBoost.plot_tree( + booster, + rankdir: :lr, + index: 4, + depth: 3, + space_between: [nodes: 200] +) +``` + +So if you want to add the space between while not changing the size of the nodes you might need to manually adjust the canvas size: + +```elixir +EXGBoost.plot_tree( + booster, + rankdir: :lr, + index: 4, + depth: 3, + space_between: [nodes: 200], + height: 800 +) +``` + +```elixir +EXGBoost.plot_tree( + booster, + rankdir: :lr, + index: 4, + depth: 3, + space_between: [levels: 200] +) +``` + +### Mark Options + +The options controlling the appearance of individual marks all conform to a similar API. You can refer to the options and pre-defined defaults for a subset of the allowed options, but you can also pass other options so long as they are allowed by the Vega Mark spec (as defined [here](#cell-y5oxrrri4daa6xt5)) + +```elixir +EXGBoost.plot_tree( + booster, + rankdir: :bt, + index: 4, + depth: 3, + space_between: [levels: 200], + yes: [ + text: [font_size: 18, fill: :teal] + ], + no: [ + text: [font_size: 20] + ], + node_width: 100 +) +``` + +Most marks accept an `:opacity` option that you can use to effectively hide the mark: + +```elixir +EXGBoost.plot_tree( + booster, + rankdir: :lr, + index: 4, + depth: 3, + splits: [ + text: [opacity: 0], + rect: [opacity: 0], + children: [opacity: 1] + ] +) +``` + +And `text` marks accept normal text options such as `:fill`, `:font_size`, and `:font`: + +```elixir +EXGBoost.plot_tree( + booster, + node_width: 250, + splits: [ + text: [font: "Helvetica Neue", font_size: 20, fill: "orange"] + ], + space_between: [levels: 20] +) +``` + ### Styles There are a set of provided pre-configured settings for the top-level API that you may optionally use. You can refer to the `EXGBoost.Plottings.Styles` docs to see a gallery of each style in action. You can specify a style with the `:style` option in `EXGBoost.plot_tree/2`. -You can still specify custom settings along with using a style. Most styles only specify a subset of the total possible settings, so you are free to specify any other allowed keys and they will be merged with the style. However, any options set by the style **do** take precedence over options. +You can still specify custom settings along with using a style. Most styles only specify a subset of the total possible settings, but you are free to specify any other allowed keys and they will be merged with the style. Any options passed explicitly to the option **does** take precedence over the style options. For example, let's look at the `:solarized_dark` style: @@ -134,28 +320,21 @@ You can see that it defines a background color of `#002b36` but does not restric EXGBoost.plot_tree(booster, style: :solarized_dark, background: "white", height: 200) ``` -We specified both `:background` and `:height` here, but only `:height` was changed because it was not specified in the style. +We specified both `:background` and `:height` here, and the background specified in the option supercedes the one from the style. -If you want to leverage a style but have the flexibility to change something it defines, you can always get the style specification as a `Keyword` which can be passed to `EXGBoost.plot_tree/2` manually, making any needed changes yourself, like so: +You can also always get the style specification as a `Keyword` which can be passed to `EXGBoost.plot_tree/2` manually, making any needed changes yourself, like so: ```elixir custom_style = EXGBoost.Plotting.solarized_dark() |> Keyword.put(:background, "white") EXGBoost.plot_tree(booster, style: custom_style) ``` -The benefits of using a style is you still get to leverage all of the defaults provided by the API. Look at the difference between changing the background as we just did by specifying the style versus using the style as the new `opts` argument: +You can also programatically check which styles are available: ```elixir -custom_style = EXGBoost.Plotting.solarized_dark() |> Keyword.put(:background, "white") -EXGBoost.plot_tree(booster, [style: nil] ++ custom_style) +EXGBoost.Plotting.get_styles() ``` -As you can see, it maintained the pieces that were **EXPLICITLY** set by the style, but lost some of the defaults that improve the plot appearance. - -Obviously, if you wish to specify all parameters yourself, this shouldn't be an issue. - - - ### Configuration You can also set defaults for the top-level API using an `Application` configuration for `EXGBoost` under the `:plotting` key. Since the defaults are collected from your configuration file at compile-time, anything you set during runtime, even if you set it to the Application environment, will not be registered as defaults. @@ -228,3 +407,9 @@ config: ``` **NOTE: When you specify a parameter in the configuration, it is merged with the defaults which is different from runtime behavior.** + +At any point, you can check what your default settings are by using `EXGBoost.Plotting.get_defaults/0` + +```elixir +EXGBoost.Plotting.get_defaults() +``` From 53ed029ccddf49f87ce28db1e4479da469e3632f Mon Sep 17 00:00:00 2001 From: acalejos Date: Wed, 24 Jan 2024 22:40:01 -0500 Subject: [PATCH 16/21] Update docs --- lib/exgboost.ex | 28 ++++++++++++++++++++ lib/exgboost/nif.ex | 1 - lib/exgboost/plotting.ex | 6 +++++ lib/exgboost/plotting/style.ex | 2 ++ mix.exs | 8 ++++-- mix.lock | 35 +++++++++--------------- notebooks/plotting.livemd | 47 +++++++++++++++++++++++++++++---- test/model.json | Bin 0 -> 30612 bytes 8 files changed, 97 insertions(+), 30 deletions(-) create mode 100644 test/model.json diff --git a/lib/exgboost.ex b/lib/exgboost.ex index dfaf6c9..71644e8 100644 --- a/lib/exgboost.ex +++ b/lib/exgboost.ex @@ -113,6 +113,34 @@ defmodule EXGBoost do - `config` - Save the configuration only. - `weights` - Save the model parameters only. Use this when you want to save the model to a format that can be ingested by other XGBoost APIs. - `model` - Save both the model parameters and the configuration. + + ## Plotting + + `EXGBoost.plot_tree/2` is the primary entry point for plotting a tree from a trained model. + It accepts an `EXGBoost.Booster` struct (which is the output of `EXGBoost.train/2`). + `EXGBoost.plot_tree/2` returns a VegaLite spec that can be rendered in a notebook or saved to a file. + `EXGBoost.plot_tree/2` also accepts a keyword list of options that can be used to configure the plotting process. + + See `EXGBoost.Plotting` for more detail on plotting. + + You can see available styles by running `EXGBoost.Plotting.get_styles()` or refer to the `EXGBoost.Plotting.Styles` + documentation for a gallery of the styles. + + ## Kino & Livebook Integration + + `EXGBoost` integrates with [Kino](https://hexdocs.pm/kino/Kino.html) and [Livebook](https://livebook.dev/) + to provide a rich interactive experience for data scientists. + + EXGBoost implements the `Kino.Render` protocol for `EXGBoost.Booster` structs. This allows you to render + a Booster in a Livebook notebook. Under the hood, `EXGBoost` uses [Vega-Lite](https://vega.github.io/vega-lite/) + and [Kino Vega-Lite](https://hexdocs.pm/kino_vega_lite/Kino.VegaLite.html) to render the Booster. + + See the [`Plotting in EXGBoost`](notebooks/plotting.livemd) Notebook for an example of how to use `EXGBoost` with `Kino` and `Livebook`. + + ## Examples + + See the example Notebooks in the left sidebar (under the `Pages` tab) for more examples and tutorials + on how to use EXGBoost. """ alias EXGBoost.ArrayInterface diff --git a/lib/exgboost/nif.ex b/lib/exgboost/nif.ex index ba09dfc..bbed9e0 100644 --- a/lib/exgboost/nif.ex +++ b/lib/exgboost/nif.ex @@ -100,7 +100,6 @@ defmodule EXGBoost.NIF do def dmatrix_create_from_file(_file_uri, _silent), do: :erlang.nif_error(:not_implemented) - @since "0.4.0" def dmatrix_create_from_uri(_config), do: :erlang.nif_error(:not_implemented) @spec dmatrix_create_from_mat(binary, integer(), integer(), float()) :: diff --git a/lib/exgboost/plotting.ex b/lib/exgboost/plotting.ex index a86c4ac..50719fb 100644 --- a/lib/exgboost/plotting.ex +++ b/lib/exgboost/plotting.ex @@ -828,10 +828,13 @@ defmodule EXGBoost.Plotting do `:background` as an option, the `:background` option will be ignored since the `:solarized_light` style defines its own value for `:background`. """ + @spec get_schema() :: ExJsonSchema.Schema.Root.t() def get_schema(), do: @schema + @spec get_defaults() :: Keyword.t() def get_defaults(), do: @defaults + @spec get_styles() :: [{atom(), [style(), ...]}, ...] def get_styles(), do: @styles defp validate_spec(spec) do @@ -904,6 +907,7 @@ defmodule EXGBoost.Plotting do - depth: The depth of the node (root node is depth 1) - leaf: The leaf value if it is a leaf node """ + @spec to_tabular(EXGBoost.Booster.t()) :: [map()] def to_tabular(booster) do booster |> EXGBoost.Booster.get_dump(format: :json) @@ -928,6 +932,7 @@ defmodule EXGBoost.Plotting do * `:rankdir` - Determines the direction of the graph. Accepts one of `:tb`, `:lr`, `:bt`, or `:rl`. Defaults to `:tb` """ + @spec get_data_spec(booster :: EXGBoost.Booster.t(), opts :: keyword()) :: map() def get_data_spec(booster, opts \\ []) do opts = NimbleOptions.validate!(opts, @plotting_schema) @@ -1258,6 +1263,7 @@ defmodule EXGBoost.Plotting do } end + @spec plot(EXGBoost.Booster.t(), Keyword.t()) :: VegaLite.t() def plot(booster, opts \\ []) do spec = get_data_spec(booster, opts) opts = NimbleOptions.validate!(opts, @plotting_schema) diff --git a/lib/exgboost/plotting/style.ex b/lib/exgboost/plotting/style.ex index d181099..b1ca860 100644 --- a/lib/exgboost/plotting/style.ex +++ b/lib/exgboost/plotting/style.ex @@ -2,6 +2,7 @@ defmodule EXGBoost.Plotting.Style do @moduledoc false defmacro __using__(_opts) do quote do + @type style :: Keyword.t() import EXGBoost.Plotting.Style Module.register_attribute(__MODULE__, :styles, accumulate: true) end @@ -41,6 +42,7 @@ defmodule EXGBoost.Plotting.Style do defmacro style(style_name, do: body) do quote do + @spec unquote(style_name)() :: style def unquote(style_name)(), do: unquote(body) Module.put_attribute(__MODULE__, :styles, {unquote(style_name), unquote(body)}) end diff --git a/mix.exs b/mix.exs index 0cf41de..67e29c6 100644 --- a/mix.exs +++ b/mix.exs @@ -19,6 +19,9 @@ defmodule EXGBoost.MixProject do start_permanent: Mix.env() == :prod, compilers: [:elixir_make] ++ Mix.compilers(), deps: deps(), + name: "EXGBoost", + source_url: "https://github.com/acalejos/exgboost", + homepage_url: "https://github.com/acalejos/exgboost", docs: docs(), package: package(), preferred_cli_env: [ @@ -46,7 +49,7 @@ defmodule EXGBoost.MixProject do {:nimble_options, "~> 1.0"}, {:nx, "~> 0.5"}, {:jason, "~> 1.3"}, - {:ex_doc, "~> 0.29.0", only: :docs}, + {:ex_doc, "~> 0.31.0", only: :docs}, {:cc_precompiler, "~> 0.1.0", runtime: false}, {:exterval, "0.1.0"}, {:ex_json_schema, "~> 0.10.2"}, @@ -82,7 +85,8 @@ defmodule EXGBoost.MixProject do extras: [ "notebooks/compiled_benchmarks.livemd", "notebooks/iris_classification.livemd", - "notebooks/quantile_prediction_interval.livemd" + "notebooks/quantile_prediction_interval.livemd", + "notebooks/plotting.livemd" ], groups_for_extras: [ Notebooks: Path.wildcard("notebooks/*.livemd") diff --git a/mix.lock b/mix.lock index 6c17710..73d39f7 100644 --- a/mix.lock +++ b/mix.lock @@ -1,41 +1,32 @@ %{ - "castore": {:hex, :castore, "1.0.4", "ff4d0fb2e6411c0479b1d965a814ea6d00e51eb2f58697446e9c41a97d940b28", [:mix], [], "hexpm", "9418c1b8144e11656f0be99943db4caf04612e3eaecefb5dae9a2a87565584f8"}, - "cc_precompiler": {:hex, :cc_precompiler, "0.1.7", "77de20ac77f0e53f20ca82c563520af0237c301a1ec3ab3bc598e8a96c7ee5d9", [:mix], [{:elixir_make, "~> 0.7.3", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "2768b28bf3c2b4f788c995576b39b8cb5d47eb788526d93bd52206c1d8bf4b75"}, + "castore": {:hex, :castore, "0.1.22", "4127549e411bedd012ca3a308dede574f43819fe9394254ca55ab4895abfa1a2", [:mix], [], "hexpm", "c17576df47eb5aa1ee40cc4134316a99f5cad3e215d5c77b8dd3cfef12a22cac"}, + "cc_precompiler": {:hex, :cc_precompiler, "0.1.9", "e8d3364f310da6ce6463c3dd20cf90ae7bbecbf6c5203b98bf9b48035592649b", [:mix], [{:elixir_make, "~> 0.7", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "9dcab3d0f3038621f1601f13539e7a9ee99843862e66ad62827b0c42b2f58a54"}, "certifi": {:hex, :certifi, "2.12.0", "2d1cca2ec95f59643862af91f001478c9863c2ac9cb6e2f89780bfd8de987329", [:rebar3], [], "hexpm", "ee68d85df22e554040cdb4be100f33873ac6051387baf6a8f6ce82272340ff1c"}, "complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"}, "decimal": {:hex, :decimal, "2.1.1", "5611dca5d4b2c3dd497dec8f68751f1f1a54755e8ed2a966c2633cf885973ad6", [:mix], [], "hexpm", "53cfe5f497ed0e7771ae1a475575603d77425099ba5faef9394932b35020ffcc"}, - "earmark_parser": {:hex, :earmark_parser, "1.4.32", "fa739a0ecfa34493de19426681b23f6814573faee95dfd4b4aafe15a7b5b32c6", [:mix], [], "hexpm", "b8b0dd77d60373e77a3d7e8afa598f325e49e8663a51bcc2b88ef41838cca755"}, - "elixir_make": {:hex, :elixir_make, "0.7.7", "7128c60c2476019ed978210c245badf08b03dbec4f24d05790ef791da11aa17c", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}], "hexpm", "5bc19fff950fad52bbe5f211b12db9ec82c6b34a9647da0c2224b8b8464c7e6c"}, - "ex_doc": {:hex, :ex_doc, "0.29.4", "6257ecbb20c7396b1fe5accd55b7b0d23f44b6aa18017b415cb4c2b91d997729", [:mix], [{:earmark_parser, "~> 1.4.31", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_elixir, "~> 0.14", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1", [hex: :makeup_erlang, repo: "hexpm", optional: false]}], "hexpm", "2c6699a737ae46cb61e4ed012af931b57b699643b24dabe2400a8168414bc4f5"}, + "earmark_parser": {:hex, :earmark_parser, "1.4.39", "424642f8335b05bb9eb611aa1564c148a8ee35c9c8a8bba6e129d51a3e3c6769", [:mix], [], "hexpm", "06553a88d1f1846da9ef066b87b57c6f605552cfbe40d20bd8d59cc6bde41944"}, + "elixir_make": {:hex, :elixir_make, "0.7.8", "505026f266552ee5aabca0b9f9c229cbb496c689537c9f922f3eb5431157efc7", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}, {:certifi, "~> 2.0", [hex: :certifi, repo: "hexpm", optional: true]}], "hexpm", "7a71945b913d37ea89b06966e1342c85cfe549b15e6d6d081e8081c493062c07"}, + "ex_doc": {:hex, :ex_doc, "0.31.1", "8a2355ac42b1cc7b2379da9e40243f2670143721dd50748bf6c3b1184dae2089", [:mix], [{:earmark_parser, "~> 1.4.39", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_c, ">= 0.1.1", [hex: :makeup_c, repo: "hexpm", optional: true]}, {:makeup_elixir, "~> 0.14", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1", [hex: :makeup_erlang, repo: "hexpm", optional: false]}], "hexpm", "3178c3a407c557d8343479e1ff117a96fd31bafe52a039079593fb0524ef61b0"}, "ex_json_schema": {:hex, :ex_json_schema, "0.10.2", "7c4b8c1481fdeb1741e2ce66223976edfb9bccebc8014f6aec35d4efe964fb71", [:mix], [{:decimal, "~> 2.0", [hex: :decimal, repo: "hexpm", optional: false]}], "hexpm", "37f43be60f8407659d4d0155a7e45e7f406dab1f827051d3d35858a709baf6a6"}, - "exonerate": {:hex, :exonerate, "1.1.3", "96c3249f0358567151d595a49d612849a7a57930575a1eb781c178f2e84b1559", [:mix], [{:finch, "~> 0.15", [hex: :finch, repo: "hexpm", optional: true]}, {:idna, "~> 6.1.1", [hex: :idna, repo: "hexpm", optional: true]}, {:jason, "~> 1.4.0", [hex: :jason, repo: "hexpm", optional: false]}, {:json_ptr, "~> 1.0", [hex: :json_ptr, repo: "hexpm", optional: false]}, {:match_spec, "~> 0.3.1", [hex: :match_spec, repo: "hexpm", optional: false]}, {:pegasus, "~> 0.2.2", [hex: :pegasus, repo: "hexpm", optional: true]}, {:req, "~> 0.3", [hex: :req, repo: "hexpm", optional: true]}, {:yaml_elixir, "~> 2.7", [hex: :yaml_elixir, repo: "hexpm", optional: true]}], "hexpm", "cb750e303ae79987153eecb474ccd011f45e3dea7c4f988d8286facdf56db1d2"}, "exterval": {:hex, :exterval, "0.1.0", "d35ec43b0f260239f859665137fac0974f1c6a8d50bf8d52b5999c87c67c63e5", [:mix], [], "hexpm", "8ed444558a501deec6563230e3124cdf242c413a95eb9ca9f39de024ad779d7f"}, - "finch": {:hex, :finch, "0.16.0", "40733f02c89f94a112518071c0a91fe86069560f5dbdb39f9150042f44dcfb1a", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:mime, "~> 1.0 or ~> 2.0", [hex: :mime, repo: "hexpm", optional: false]}, {:mint, "~> 1.3", [hex: :mint, repo: "hexpm", optional: false]}, {:nimble_options, "~> 0.4 or ~> 1.0", [hex: :nimble_options, repo: "hexpm", optional: false]}, {:nimble_pool, "~> 0.2.6 or ~> 1.0", [hex: :nimble_pool, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "f660174c4d519e5fec629016054d60edd822cdfe2b7270836739ac2f97735ec5"}, "fss": {:hex, :fss, "0.1.1", "9db2344dbbb5d555ce442ac7c2f82dd975b605b50d169314a20f08ed21e08642", [:mix], [], "hexpm", "78ad5955c7919c3764065b21144913df7515d52e228c09427a004afe9c1a16b0"}, "hackney": {:hex, :hackney, "1.20.1", "8d97aec62ddddd757d128bfd1df6c5861093419f8f7a4223823537bad5d064e2", [:rebar3], [{:certifi, "~> 2.12.0", [hex: :certifi, repo: "hexpm", optional: false]}, {:idna, "~> 6.1.0", [hex: :idna, repo: "hexpm", optional: false]}, {:metrics, "~> 1.0.0", [hex: :metrics, repo: "hexpm", optional: false]}, {:mimerl, "~> 1.1", [hex: :mimerl, repo: "hexpm", optional: false]}, {:parse_trans, "3.4.1", [hex: :parse_trans, repo: "hexpm", optional: false]}, {:ssl_verify_fun, "~> 1.1.0", [hex: :ssl_verify_fun, repo: "hexpm", optional: false]}, {:unicode_util_compat, "~> 0.7.0", [hex: :unicode_util_compat, repo: "hexpm", optional: false]}], "hexpm", "fe9094e5f1a2a2c0a7d10918fee36bfec0ec2a979994cff8cfe8058cd9af38e3"}, - "hpax": {:hex, :hpax, "0.1.2", "09a75600d9d8bbd064cdd741f21fc06fc1f4cf3d0fcc335e5aa19be1a7235c84", [:mix], [], "hexpm", "2c87843d5a23f5f16748ebe77969880e29809580efdaccd615cd3bed628a8c13"}, "httpoison": {:hex, :httpoison, "2.2.1", "87b7ed6d95db0389f7df02779644171d7319d319178f6680438167d7b69b1f3d", [:mix], [{:hackney, "~> 1.17", [hex: :hackney, repo: "hexpm", optional: false]}], "hexpm", "51364e6d2f429d80e14fe4b5f8e39719cacd03eb3f9a9286e61e216feac2d2df"}, "idna": {:hex, :idna, "6.1.1", "8a63070e9f7d0c62eb9d9fcb360a7de382448200fbbd1b106cc96d3d8099df8d", [:rebar3], [{:unicode_util_compat, "~> 0.7.0", [hex: :unicode_util_compat, repo: "hexpm", optional: false]}], "hexpm", "92376eb7894412ed19ac475e4a86f7b413c1b9fbb5bd16dccd57934157944cea"}, - "jason": {:hex, :jason, "1.4.0", "e855647bc964a44e2f67df589ccf49105ae039d4179db7f6271dfd3843dc27e6", [:mix], [{:decimal, "~> 1.0 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: true]}], "hexpm", "79a3791085b2a0f743ca04cec0f7be26443738779d09302e01318f97bdb82121"}, - "json_ptr": {:hex, :json_ptr, "1.2.0", "07a757d1b0a86b7fd73f5a5b26d4d41c5bc7d5be4a3e3511d7458293ce70b8bb", [:mix], [{:jason, "~> 1.4.0", [hex: :jason, repo: "hexpm", optional: false]}], "hexpm", "e58704ac304cbf3832c0ac161e76479e7b05f75427991ddd57e19b307ae4aa05"}, - "kino": {:hex, :kino, "0.11.3", "c48a3f8bcad5ec103884aee119dc01594c0ead7c755061853147357c01885f2c", [:mix], [{:fss, "~> 0.1.0", [hex: :fss, repo: "hexpm", optional: false]}, {:nx, "~> 0.1", [hex: :nx, repo: "hexpm", optional: true]}, {:table, "~> 0.1.2", [hex: :table, repo: "hexpm", optional: false]}], "hexpm", "027e1f634ce175e9092d9bfc632c911ca50ad735eecc77770bfaf4b4a8b2932f"}, + "jason": {:hex, :jason, "1.4.1", "af1504e35f629ddcdd6addb3513c3853991f694921b1b9368b0bd32beb9f1b63", [:mix], [{:decimal, "~> 1.0 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: true]}], "hexpm", "fbb01ecdfd565b56261302f7e1fcc27c4fb8f32d56eab74db621fc154604a7a1"}, + "kino": {:hex, :kino, "0.12.3", "a5f48a243c60a7ac18ba23869f697b1c775fc7794e8cd55dd248ba33c6fe9445", [:mix], [{:fss, "~> 0.1.0", [hex: :fss, repo: "hexpm", optional: false]}, {:nx, "~> 0.1", [hex: :nx, repo: "hexpm", optional: true]}, {:table, "~> 0.1.2", [hex: :table, repo: "hexpm", optional: false]}], "hexpm", "a6dfa3d54ba0edec9ca6e5940154916b381901001f171c85a2d8c67869dbc2d8"}, "kino_vega_lite": {:hex, :kino_vega_lite, "0.1.11", "d3c2a00b3685b95f91833920d06cc9b1fd7fb293a2663d89affe9aaec16a5b77", [:mix], [{:kino, "~> 0.7", [hex: :kino, repo: "hexpm", optional: false]}, {:table, "~> 0.1.0", [hex: :table, repo: "hexpm", optional: false]}, {:vega_lite, "~> 0.1.8", [hex: :vega_lite, repo: "hexpm", optional: false]}], "hexpm", "5ccd9148ce7cfcc95a137e12596cd8b95b371e9ea107e745bc262c39c5d8d48e"}, - "makeup": {:hex, :makeup, "1.1.0", "6b67c8bc2882a6b6a445859952a602afc1a41c2e08379ca057c0f525366fc3ca", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "0a45ed501f4a8897f580eabf99a2e5234ea3e75a4373c8a52824f6e873be57a6"}, + "makeup": {:hex, :makeup, "1.1.1", "fa0bc768698053b2b3869fa8a62616501ff9d11a562f3ce39580d60860c3a55e", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "5dc62fbdd0de44de194898b6710692490be74baa02d9d108bc29f007783b0b48"}, "makeup_elixir": {:hex, :makeup_elixir, "0.16.1", "cc9e3ca312f1cfeccc572b37a09980287e243648108384b97ff2b76e505c3555", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "e127a341ad1b209bd80f7bd1620a15693a9908ed780c3b763bccf7d200c767c6"}, - "makeup_erlang": {:hex, :makeup_erlang, "0.1.2", "ad87296a092a46e03b7e9b0be7631ddcf64c790fa68a9ef5323b6cbb36affc72", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "f3f5a1ca93ce6e092d92b6d9c049bcda58a3b617a8d888f8e7231c85630e8108"}, - "match_spec": {:hex, :match_spec, "0.3.1", "9fb2b9313dbd2c5aa6210a8c11eea94e0531a0982e463d9467445e8615081af0", [:mix], [], "hexpm", "225e90aa02c8023b12cf47e33c8a7307ac88b31f461b43d93e040f4641704557"}, + "makeup_erlang": {:hex, :makeup_erlang, "0.1.3", "d684f4bac8690e70b06eb52dad65d26de2eefa44cd19d64a8095e1417df7c8fd", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "b78dc853d2e670ff6390b605d807263bf606da3c82be37f9d7f68635bd886fc9"}, "metrics": {:hex, :metrics, "1.0.1", "25f094dea2cda98213cecc3aeff09e940299d950904393b2a29d191c346a8486", [:rebar3], [], "hexpm", "69b09adddc4f74a40716ae54d140f93beb0fb8978d8636eaded0c31b6f099f16"}, - "mime": {:hex, :mime, "2.0.5", "dc34c8efd439abe6ae0343edbb8556f4d63f178594894720607772a041b04b02", [:mix], [], "hexpm", "da0d64a365c45bc9935cc5c8a7fc5e49a0e0f9932a761c55d6c52b142780a05c"}, "mimerl": {:hex, :mimerl, "1.2.0", "67e2d3f571088d5cfd3e550c383094b47159f3eee8ffa08e64106cdf5e981be3", [:rebar3], [], "hexpm", "f278585650aa581986264638ebf698f8bb19df297f66ad91b18910dfc6e19323"}, - "mint": {:hex, :mint, "1.5.1", "8db5239e56738552d85af398798c80648db0e90f343c8469f6c6d8898944fb6f", [:mix], [{:castore, "~> 0.1.0 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}, {:hpax, "~> 0.1.1", [hex: :hpax, repo: "hexpm", optional: false]}], "hexpm", "4a63e1e76a7c3956abd2c72f370a0d0aecddc3976dea5c27eccbecfa5e7d5b1e"}, "nimble_csv": {:hex, :nimble_csv, "1.2.0", "4e26385d260c61eba9d4412c71cea34421f296d5353f914afe3f2e71cce97722", [:mix], [], "hexpm", "d0628117fcc2148178b034044c55359b26966c6eaa8e2ce15777be3bbc91b12a"}, - "nimble_options": {:hex, :nimble_options, "1.0.2", "92098a74df0072ff37d0c12ace58574d26880e522c22801437151a159392270e", [:mix], [], "hexpm", "fd12a8db2021036ce12a309f26f564ec367373265b53e25403f0ee697380f1b8"}, - "nimble_parsec": {:hex, :nimble_parsec, "1.3.1", "2c54013ecf170e249e9291ed0a62e5832f70a476c61da16f6aac6dca0189f2af", [:mix], [], "hexpm", "2682e3c0b2eb58d90c6375fc0cc30bc7be06f365bf72608804fb9cffa5e1b167"}, - "nimble_pool": {:hex, :nimble_pool, "1.0.0", "5eb82705d138f4dd4423f69ceb19ac667b3b492ae570c9f5c900bb3d2f50a847", [:mix], [], "hexpm", "80be3b882d2d351882256087078e1b1952a28bf98d0a287be87e4a24a710b67a"}, - "nx": {:hex, :nx, "0.5.3", "6ad5534f9b82429dafa12329952708c2fdd6ab01b306e86333fdea72383147ee", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "d1072fc4423809ed09beb729e73c200ce177ddecac425d9eb6fba643669623ec"}, + "nimble_options": {:hex, :nimble_options, "1.1.0", "3b31a57ede9cb1502071fade751ab0c7b8dbe75a9a4c2b5bbb0943a690b63172", [:mix], [], "hexpm", "8bbbb3941af3ca9acc7835f5655ea062111c9c27bcac53e004460dfd19008a99"}, + "nimble_parsec": {:hex, :nimble_parsec, "1.4.0", "51f9b613ea62cfa97b25ccc2c1b4216e81df970acd8e16e8d1bdc58fef21370d", [:mix], [], "hexpm", "9c565862810fb383e9838c1dd2d7d2c437b3d13b267414ba6af33e50d2d1cf28"}, + "nx": {:hex, :nx, "0.6.4", "948d9f42f81e63fc901d243ac0a985c8bb87358be62e27826cfd67f58bc640af", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "bb9c2e2e3545b5eb4739d69046a988daaa212d127dba7d97801c291616aff6d6"}, "parse_trans": {:hex, :parse_trans, "3.4.1", "6e6aa8167cb44cc8f39441d05193be6e6f4e7c2946cb2759f015f8c56b76e5ff", [:rebar3], [], "hexpm", "620a406ce75dada827b82e453c19cf06776be266f5a67cff34e1ef2cbb60e49a"}, - "req": {:hex, :req, "0.4.5", "2071bbedd280f107b9e33e1ddff2beb3991ec1ae06caa2cca2ab756393d8aca5", [:mix], [{:brotli, "~> 0.3.1", [hex: :brotli, repo: "hexpm", optional: true]}, {:ezstd, "~> 1.0", [hex: :ezstd, repo: "hexpm", optional: true]}, {:finch, "~> 0.9", [hex: :finch, repo: "hexpm", optional: false]}, {:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: false]}, {:mime, "~> 1.6 or ~> 2.0", [hex: :mime, repo: "hexpm", optional: false]}, {:nimble_csv, "~> 1.0", [hex: :nimble_csv, repo: "hexpm", optional: true]}, {:plug, "~> 1.0", [hex: :plug, repo: "hexpm", optional: true]}], "hexpm", "dd23e9c7303ddeb2dee09ff11ad8102cca019e38394456f265fb7b9655c64dd8"}, - "scidata": {:hex, :scidata, "0.1.8", "80b1a470efb4f4a2fc2a7d26711217948d8bc2e17f636d41d846035b44f0ce8a", [:mix], [{:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: false]}, {:nimble_csv, "~> 1.1", [hex: :nimble_csv, repo: "hexpm", optional: false]}, {:stb_image, "~> 0.4", [hex: :stb_image, repo: "hexpm", optional: true]}], "hexpm", "8fb125a2a55f52fea78f708f7bbd7bb3b8c233bf3943b411bc659252bcb12bab"}, + "scidata": {:hex, :scidata, "0.1.11", "fe3358bac7d740374b4f2a7eff6a1cb02e5ee7f87f7cdb1e8648ad93c533165f", [:mix], [{:castore, "~> 0.1", [hex: :castore, repo: "hexpm", optional: false]}, {:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: false]}, {:nimble_csv, "~> 1.1", [hex: :nimble_csv, repo: "hexpm", optional: false]}, {:stb_image, "~> 0.4", [hex: :stb_image, repo: "hexpm", optional: true]}], "hexpm", "90873337a9d5fe880d640517efa93d3c07e46c8ba436de44117f581800549f93"}, "ssl_verify_fun": {:hex, :ssl_verify_fun, "1.1.7", "354c321cf377240c7b8716899e182ce4890c5938111a1296add3ec74cf1715df", [:make, :mix, :rebar3], [], "hexpm", "fe4c190e8f37401d30167c8c405eda19469f34577987c76dde613e838bbc67f8"}, "table": {:hex, :table, "0.1.2", "87ad1125f5b70c5dea0307aa633194083eb5182ec537efc94e96af08937e14a8", [:mix], [], "hexpm", "7e99bc7efef806315c7e65640724bf165c3061cdc5d854060f74468367065029"}, "telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"}, diff --git a/notebooks/plotting.livemd b/notebooks/plotting.livemd index 22a18aa..2305d36 100644 --- a/notebooks/plotting.livemd +++ b/notebooks/plotting.livemd @@ -1,17 +1,18 @@ # Plotting in EXGBoost ```elixir -# Mix.install([ -# {:exgboost, path: Path.join(__DIR__, ".."), env: :dev}, -# ], -# lockfile: :exgboost) +# Mix.install( +# [ +# {:exgboost, "~> 0.5", env: :dev} +# ] +# ) Mix.install([ {:exgboost, path: "/Users/andres/Documents/exgboost"}, {:nx, "~> 0.5"}, {:scidata, "~> 0.1"}, {:scholar, "~> 0.1"}, - {:kino_vega_lite, "~> 0.1.9"} + {:kino_vega_lite, "~> 0.1"} ]) # This assumed you launch this livebook from its location in the exgboost/notebooks folder @@ -413,3 +414,39 @@ At any point, you can check what your default settings are by using `EXGBoost.Pl ```elixir EXGBoost.Plotting.get_defaults() ``` + +## Low-Level API + +If you find yourself needing more granular control over your plots, you can reach towards the `EXGBoost.Plotting` module. This module houses the `EXGBoost.Plotting.plot/2` function, which is what is used under the hood from the `EXGBoost.plot_tree/2` top-level API. This module also has the `get_data_spec/2` function, as well as the `to_tabular/1` function, both of which can be used to specify your own Vega specification. Lastly, the module also houses all of the pre-configured styles, which are 0-arity functions which output the `Keyword`s containing their respective style's options that can be passed to the plotting APIs. + +Let's briefly go over the `to_tabular/1` and `get_data_spec/2` functions: + + + +The `to_tabular/1` function is used to convert a `Booster`, which is formatted as a tree structure, to a tabular format which can be ingested specifically by the [Vega Stratify transform](https://vega.github.io/vega/docs/transforms/stratify/). It returns a list of "nodes", which are just `Map`s with info about each node in the tree. + +```elixir +EXGBoost.Plotting.to_tabular(booster) |> hd +``` + +You can use this function if you want to have complete control over the visualization, and just want a bit of a head start with respect to data transformation for converting the `Booster` into a more digestible format. + + + +The `get_data_source/2` function is used if you want to use the provided [Vega data specification](https://vega.github.io/vega/docs/data/). This is for those who want to only focus on implementing your own [Vega Marks](https://vega.github.io/vega/docs/marks/), and want to leverage the data transformation pipeline that powers the top-level API. + +The data transformation used is the following pipeline: + +`to_tabular/1` -> [Filter](https://vega.github.io/vega/docs/transforms/filter/) (by tree index) -> [Stratify](https://vega.github.io/vega/docs/transforms/stratify/) -> [Tree](https://vega.github.io/vega/docs/transforms/tree/) + +```elixir +EXGBoost.Plotting.get_data_spec(booster, rankdir: :bt) +``` + +The Vega fields which are not included with `get_data_spec/2` and are included in `plot/2` are: + +* [Marks](https://vega.github.io/vega/docs/marks/) +* [Scales](https://vega.github.io/vega/docs/scales/) +* [Signals](https://vega.github.io/vega/docs/signals/) + +You can make a completely valid plot using only the Data from `get_data_specs/2` and adding the marks you need. diff --git a/test/model.json b/test/model.json new file mode 100644 index 0000000000000000000000000000000000000000..a3575a8b1fa294e8e1d07d0459a91bb8f6b7f573 GIT binary patch literal 30612 zcmdsA3w#sB_HLWDAx+;!k& zKq!ccs3<5e5%fa46tp0Uf*@Bxy(lPpm6wRGB64T8J(F#^Q{`{|x#Z68m+WRUk8@_e z@0>4_O*{Fvh}sg{$L(^+`IF(b(J9$IF3H0#B<4vDd$H4JE0la5Id>3LY*OfS=R4hA zFI-5-mqxoGY<*hTNiJXwuAKmz_4Z%O<>Q9eJK zBA{lh(_Wa9rx}XPNP-$LWLLh8Fvahwp1H?f07LJEB|*Y2}XI(rM; zPCwwHdUR`8%T3~jRN&Su604L*yCRhG*$MTfA~pg3`7zbjAk*Q0lUWi+-y~j**U-w} zG`Wy)DYrm&=Gi7la(;ov05&Iyw2rsPDf?_RLi)#G@vLc~+vWBV4|JzJszXO`INl>U z=zLIv9R7~jP$rL*Z{tBZj=srVxg-_Hbmd5B*Am=V?2;XBPa(6xd#~inEuaf5yZtsW z7$$p3$oEP*E0=a%M%P-uNcW{sbN3YY1bkv3#Q3TbAm&%`-(M# zL^5u~sVT$m$(OXiZD_vU>>BG57+bG^INT~v_S$nuXavb1krWxI8~p2#SV&}rOejhN zN4`H!rBPbAi+x4KJ~au`=m_4K+&N<<60gSl=cTrr-d?xEM{-&YbH!LISiE`&s-I4n zlwocC!!Xd*ZxS#lnAeOeY&4H*ri2A}R)OjPy$XC86Qk-EW)P^3|1VXO%tRTH zjd?{rkAJGnWTFgmtZb6w;(zC>;=x02hy>X=k9AiErHf0H((Nu)Z%|A%MjIEHqt znY1=+GU8Rn*Cq<@$aZ-wni(nC_+ki6i+>%6EHY(pftvtIHNC+5f&8 zwDfDGX=-_;{6Zo`e6=~! zBe}Gm&?b@R>k71ur#X^Tx4ZJG37R>|&gm`HeNJJW8%eII^~gU#hAE`Wz>0|;nw`03 zL^a%q@`uIn zn&IUelwp%MGT;0d*3K(K!NtV{6D&2ekMbcBoZ<6rBgLKr?m+2>I?&hJ0=?+*$25hErKwDzOCQuE`iQ0v=%5i^QFl~c)wPKw z$0pzg9HE~^`DUVgqhLWt9S(76M1)vH_{RYYpNFbJ#0XlbBaD;kmAWS|h{Ml@i_2rF zDOj!5S17G+%gIbT3{qSZiGtZ1uPez1NHC^mUFPLokkw=4cYy_-;JO8>NNi4^eWb0 zNDv=1m1o{GZT+hXes8}ptM_3v^4;<+sWod*QR+!c4OT2$9ob6ugy*ftPge;GnmoXK z3@f!~u7zqnyi#T`g~VqDZQ4NZk!vOh+1ew_4mZ1)ak70nH3Kg=>_g8)Ekzx!++!Wk ziC$_@pf^2cnWvC(H7#f+p$`gL`cU;BInn?g106Mzji{RqbRfAFr*#j`p}$76lvxys z9we}0dT`0E_r>S4vKoUGpSh}m#2{LzgW#}wA@B(d}&0z_r9?Zv;F^o*9dmA`K> zzc?6qx|b-zq9N8b&(F3#F{?r-zvd8gDy%4PtiuFYPwCExz^6NHriw;#i~+pE5&GiN z+lW7JHAGCDUn;!2ED)28CQ9QJGKQvwDk8~i73DB&BlM+a)4G)Qv0yG7p?#xtHELb zuZQ9OF7pf(Pb3IJQN_JV&%?8z)9{v?vET-uWfUnFAUE;qn|F~GYJtx6l47XjX25ds zSpkfwXE}!jv@STuV}K#M>t#RTC*v36-QDki%3-#tXx`*mH$%xxn<+TZdX!;~-(q+c zAz|A(%bxEtET=1`E1BPWt!plCBNRoaG9SY#_vq;hntzPo4xbscX%OY6SeC(|ZfDqj5Zbq1FAk-}~3&xe1b(T(k+zTl*mHarf2eoarc@X4%2U4Moli zqTDoDo~#(YB`W_H)N5m@(lcwl5Iu3Wa3JCbp(t~MFyy_rDiTk!uK>d<88+N{ro?Xo zE5#>1ZP-9~l56$8TYzr30@Q@(y!eUO&a&Kc&kKOg2$b=!{- zc{DaM}KH91Sc6EJ5Y_ae*bWSj_(LzDr$e?|%V5nH} zZHh3ke_N&V)-BA3uu9Nctws6CmV(bu+T=ncYte{r0ah?adVC9&WDQs@K2Lxns_|es zhY#9mn|jT584iC!ltsRrAoWS<-*OM)u6O)amtAg=Akx6 zPj23ZU$VxmYo}cG>OIz3XPTg^J1@7I<_#AHZ)wYX4J&3O0BGy2KjlaqIK<}&ZTy$S zE2(D=+@0X&v;=YHF)KAktm%3gGN-IXadUQAU$s`zYuX~r>YEI&!B{=cB}^e>Lq&2_ zH(AJ$!kIqm?VXa>ar6NlHnIsr-TzB_sH9$ep!<#Hzywhwd#MC~s<x5Ae66U%~VHW#A6i zoW(Q8RG`%pTB64@&*Nt8HlrVZ>xX^W1JUw3OR;UDUm1oXibQ1wEsU5fudX|&93PX1 zT9>_`bb7p{^~;0Dg|#CF3%A>s3cX*LRk7p>cDwQag6uT|R`?+!ShCl8+d0u7ePtbR z<5YkxYB6CsoYNy1qZbk`cG+787FG7@fQ3)-P|M!*+bwui$1Ct`?@;`7*5l~>@Jqx; zvTnhH1{lSZVKu0xmH)oxP|!D+vUg{X4MOkKe<)WCW%qZ47fZEzji&T?pah?vw8@3Y z-fOlO0Oz1{J-&rX_695$pC|RwA6U*waoTC*B{l>8<@Ibly5lfBv+V-bUq}!iw0vdW zG=7Yzw7hSkkh$R+6mj2G$^plRDCy)2mJLaJttpeYDy`ms**a==iO{n`VLpbH+|}Ef zqs&MEclgYpP5+YErAiOF+9UL)$EM+wjuEI3n$JyyK=U{NVkW z*gQ8H-!s=OrYziv9*sRNp1EQG`oB@e{!rKJkMkS)&pTp?xvw)>$X20H?^X zd@`KVBN(Gc^ow1x)`3NptU6%f6FStA)lk|8&sq2m_WUpzPrm+Ebbebqacc4m-1Xl* z#Telr`sr;OtV3wNY7OE}?Z12X7Oq8`lr|GzVLpTvOSM{y@-qqC;PaC>8!H6ASIBsT3d@)DW>zhnCrcRM)>kKE&9{e>c7MFvr_N_MQI zJhxW3EX^c*HLaymeQ+^~JZZCRIJL}bee|f(bksI$=Tk$4j>QGc$FO3Ao<^a}NCtQK z%m}t*HUGYbf=`sk{wyr+bP7eBJZe3^3CL9M2*&9#D_kTiffbXiRki!Cjlqh~Ts_HJ zZ??2AVCFc?)`J6&;b#14-@f9z%U0o+pMDCDUjC-|O@|uX>+Zf{iggE?>U&z;Vcv%h z^f`vxWb9<)h9Uw4RkD6Qewni4@Y}+#ONJ@w_A=|sH_j36OYAIk#Dj(AuWYS2G*@8G zg%$PnC2I;e!zX@_B&)c7qZO@cUt?XfU<^Q1INEX%oYNy1qet|MU9#4JMU|{NVBr%w z#F7;s+_e)s1`fajZ@Cfoo!q#~A+-Iu zm(skPWwh|#Wo0)p^7(n8l66CAC2+1WzJ*G%1}qn!C&8DjOV|8_Q_B0{E<2XNOoZk` zS`Z~`(kl+dl+#s6xowQF>&wrT^(nUtt;?qiZBOm7HXZkdV!pGpV7h-y<*NI)F(1P! zZ|P|iDp}LO9X>OHELqXc^*tza#LJ)TEi7tzN_emNEofNzT40m*2*&AUvT%{C1XikK z?L>ZhMR!>rCqEi)1Xg_J21-_XI{@4_p`A4K9A+*kl67N;-^6F1{TP>F1GY?>Bvz%g z6z6T9EPmZ-0(vCsP4SsKoMM+1OVP>gE7`bVCRr0KXI4L~y!2UX;dFJiwQOYzpqp^Fov8gGzea-<7+ZYbsw%onBFB*vgy@t5BeSdk75%O~E04FrbYo z@Ms#LH=@!X4(a|5lU%=$`->*W)}a2^zb;gC>4^s4JJga|O*&{;Ec40;e`*UDs}~#s z?=PYzYxzRWrw_Q#=z{Jqs%sZZj!noz_xG9P`UT!!#DYW>iH*cbal6Bz2S6G`0s+Vm z8g#v=qttm?xL?a?--DUu1P^@x@jme;;0;f;#kZDD!);e>7DsecaMM3h(UVWTf#Yj# zL?7K8kG_Z>Cmt0)5{=`RqN$5qsBF+)W&Wz31N3J(-5 zU^xv6m%s~tP*4FDcsCvQdsc%u<5UU7A}!Fl9*IMxodcGOPbTUoBM?~5A&GVxdDJlA zDTlJfZ~k>1HlMnN^%n}r2XVXihEYGXS0V>x3dX5xgr(6xD^d6E5)MpwvGUscI}0tF zep->dc5vl455HnvndbK~tQf)5&dp%S_{?aaoheHQigeQ^bjM|m8O?9`E@ch+vdLZ) zdv*yu7v6||0U5Wex98H=P@p$GX8oCVR)1nbI-)6#{g$+m+iq29TIU-3xd~~jn;yZi zMsK&R{nG^mR^oDpgZ})4!b)69(jxuY2?thu=4!PwWiVwYE%Y2xqk4fpfq@)m>%oD? za5G*MwLzRb?F9C=TZ63;UBqEeoW}|63?eF8E1r0A9a_C}C|c9680{|i|6!<5M1Y`f z_kQtFy0UgaQ{nA_)k@Mn$$Dq;*H*&_QE0xk*!s4(uJXY%?B`eEwUMO1-Fqn@girin zZTDt=S%WP9P=vkjbOL+|%U6tmb1hu;h<*`n_fq{z2~G7bfJG!9j`hI8Cp4X3ATp4d zPDVX4!!mJ*+;DMuG&bRK;~?Dh>@?get(|y!MhyCV$M52`pItA0(X$hpUC)`hJ^8;YAB{uE)1f zN!Ebn;`4<18C=PFYvxGt&BFJ@&(9o#nFzP--d@TRohwg|+iE@4slQNGV5UTzK(~W5VYry3liN0ps+T6(*9E!U`lSg%y*m99Z$0``?qSQyxQN->QLl zz@cmL6{8}>jQE$u?-D)~yARG0l`&(`>J8h_;t{RTp4(rB6$?oO2&!cLIqz-dxrDjG zy1k`}DZi7o>*x~8zO_??k+Xs?^E=#m?XB} z|A6@VBTdjl%OcRuCTC$CLPFnQO4in+rwg$QdMQRH`vZL8#nM2@+6tC}&(8~$tWEcH z1kOR{dVC9&WDQs@K2L%#Sug+EF3$a7qqz0}`%zJ75FfPf70jEnTth1l$DOtAYas~t z&*)ZpSl%nF+dRL*-tq&%Wd3JGzXGYUdeuqmW4mX*0?(Rd+wFQjRV^+9GRst&~SqZF|WaYq$&)ol>WW9Uhy<(f_ zQ@G17U2y!ipd6yb_ZjG?{p5#XG+avYabD32+QA>Nnhoqn}k5pd1J3CYjL{?d#V%Rvz(W68i5^(^gyuT{Z K#rXf-O8x_ny!HtI literal 0 HcmV?d00001 From 4b8362914ce701814e42b504579bbb9e24291a23 Mon Sep 17 00:00:00 2001 From: acalejos Date: Wed, 24 Jan 2024 23:00:32 -0500 Subject: [PATCH 17/21] update mix.exs docs --- lib/exgboost.ex | 20 ++++++++++ mix.exs | 15 ++++++++ notebooks/plotting.livemd | 78 ++++++++++++++++++--------------------- 3 files changed, 70 insertions(+), 43 deletions(-) diff --git a/lib/exgboost.ex b/lib/exgboost.ex index 71644e8..fffbd76 100644 --- a/lib/exgboost.ex +++ b/lib/exgboost.ex @@ -157,6 +157,7 @@ defmodule EXGBoost do Returns a map containing information about the build. """ @spec xgboost_build_info() :: map() + @doc type: :system def xgboost_build_info, do: EXGBoost.NIF.xgboost_build_info() |> Internal.unwrap!() |> Jason.decode!() @@ -166,6 +167,7 @@ defmodule EXGBoost do Returns a 3-tuple in the form of `{major, minor, patch}`. """ @spec xgboost_version() :: {integer(), integer(), integer()} | {:error, String.t()} + @doc type: :system def xgboost_version, do: EXGBoost.NIF.xgboost_version() |> Internal.unwrap!() @doc """ @@ -176,6 +178,7 @@ defmodule EXGBoost do for the full list of parameters supported in the global configuration. """ @spec set_config(map()) :: :ok | {:error, String.t()} + @doc type: :system def set_config(%{} = config) do config = EXGBoost.Parameters.validate_global!(config) EXGBoost.NIF.set_global_config(Jason.encode!(config)) |> Internal.unwrap!() @@ -189,6 +192,7 @@ defmodule EXGBoost do for the full list of parameters supported in the global configuration. """ @spec get_config() :: map() + @doc type: :system def get_config do EXGBoost.NIF.get_global_config() |> Internal.unwrap!() |> Jason.decode!() end @@ -237,6 +241,7 @@ defmodule EXGBoost do * `opts` - Refer to `EXGBoost.Parameters` for the full list of options. """ @spec train(Nx.Tensor.t(), Nx.Tensor.t(), Keyword.t()) :: EXGBoost.Booster.t() + @doc type: :train_pred def train(x, y, opts \\ []) do x = Nx.concatenate(x) y = Nx.concatenate(y) @@ -301,6 +306,7 @@ defmodule EXGBoost do Returns an Nx.Tensor containing the predictions. """ + @doc type: :train_pred def predict(%Booster{} = bst, x, opts \\ []) do x = Nx.concatenate(x) {dmat_opts, opts} = Keyword.split(opts, Internal.dmatrix_feature_opts()) @@ -331,6 +337,7 @@ defmodule EXGBoost do Returns an Nx.Tensor containing the predictions. """ + @doc type: :train_pred def inplace_predict(%Booster{} = boostr, data, opts \\ []) do opts = Keyword.validate!(opts, @@ -457,6 +464,7 @@ defmodule EXGBoost do ## Options #{NimbleOptions.docs(@write_schema)} """ + @doc type: :serialization @spec write_model(Booster.t(), String.t()) :: :ok | {:error, String.t()} def write_model(%Booster{} = booster, path, opts \\ []) do opts = NimbleOptions.validate!(opts, @write_schema) @@ -466,6 +474,7 @@ defmodule EXGBoost do @doc """ Read a model from a file and return the Booster. """ + @doc type: :serialization @spec read_model(String.t()) :: EXGBoost.Booster.t() def read_model(path) do EXGBoost.Booster.load(path, deserialize: :model) @@ -478,6 +487,7 @@ defmodule EXGBoost do #{NimbleOptions.docs(@dump_schema)} """ @spec dump_model(Booster.t()) :: binary() + @doc type: :serialization def dump_model(%Booster{} = booster, opts \\ []) do opts = NimbleOptions.validate!(opts, @dump_schema) EXGBoost.Booster.save(booster, opts ++ [serialize: :model, to: :buffer]) @@ -487,6 +497,7 @@ defmodule EXGBoost do Read a model from a buffer and return the Booster. """ @spec load_model(binary()) :: EXGBoost.Booster.t() + @doc type: :serialization def load_model(buffer) do EXGBoost.Booster.load(buffer, deserialize: :model, from: :buffer) end @@ -498,6 +509,7 @@ defmodule EXGBoost do #{NimbleOptions.docs(@write_schema)} """ @spec write_config(Booster.t(), String.t()) :: :ok | {:error, String.t()} + @doc type: :serialization def write_config(%Booster{} = booster, path, opts \\ []) do opts = NimbleOptions.validate!(opts, @write_schema) EXGBoost.Booster.save(booster, opts ++ [path: path, serialize: :config]) @@ -510,6 +522,7 @@ defmodule EXGBoost do #{NimbleOptions.docs(@dump_schema)} """ @spec dump_config(Booster.t()) :: binary() + @doc type: :serialization def dump_config(%Booster{} = booster, opts \\ []) do opts = NimbleOptions.validate!(opts, @dump_schema) EXGBoost.Booster.save(booster, opts ++ [serialize: :config, to: :buffer]) @@ -522,6 +535,7 @@ defmodule EXGBoost do #{NimbleOptions.docs(@load_schema)} """ @spec read_config(String.t()) :: EXGBoost.Booster.t() + @doc type: :serialization def read_config(path, opts \\ []) do opts = NimbleOptions.validate!(opts, @load_schema) EXGBoost.Booster.load(path, opts ++ [deserialize: :config]) @@ -534,6 +548,7 @@ defmodule EXGBoost do #{NimbleOptions.docs(@load_schema)} """ @spec load_config(binary()) :: EXGBoost.Booster.t() + @doc type: :serialization def load_config(buffer, opts \\ []) do opts = NimbleOptions.validate!(opts, @load_schema) EXGBoost.Booster.load(buffer, opts ++ [deserialize: :config, from: :buffer]) @@ -546,6 +561,7 @@ defmodule EXGBoost do #{NimbleOptions.docs(@write_schema)} """ @spec write_weights(Booster.t(), String.t()) :: :ok | {:error, String.t()} + @doc type: :serialization def write_weights(%Booster{} = booster, path, opts \\ []) do opts = NimbleOptions.validate!(opts, @write_schema) EXGBoost.Booster.save(booster, opts ++ [path: path, serialize: :weights]) @@ -558,6 +574,7 @@ defmodule EXGBoost do #{NimbleOptions.docs(@dump_schema)} """ @spec dump_weights(Booster.t()) :: binary() + @doc type: :serialization def dump_weights(%Booster{} = booster, opts \\ []) do opts = NimbleOptions.validate!(opts, @dump_schema) EXGBoost.Booster.save(booster, opts ++ [serialize: :weights, to: :buffer]) @@ -567,6 +584,7 @@ defmodule EXGBoost do Read a model's trained parameters from a file and return the Booster. """ @spec read_weights(String.t()) :: EXGBoost.Booster.t() + @doc type: :serialization def read_weights(path) do EXGBoost.Booster.load(path, deserialize: :weights) end @@ -575,6 +593,7 @@ defmodule EXGBoost do Read a model's trained parameters from a buffer and return the Booster. """ @spec load_weights(binary()) :: EXGBoost.Booster.t() + @doc type: :serialization def load_weights(buffer) do EXGBoost.Booster.load(buffer, deserialize: :weights, from: :buffer) end @@ -588,6 +607,7 @@ defmodule EXGBoost do * `:path` - the path to save the graphic to. If not provided, the graphic is returned as a VegaLite spec. * `:opts` - additional options to pass to `EXGBoost.Plotting.plot/2`. See `EXGBoost.Plotting` for more information. """ + @doc type: :plotting def plot_tree(booster, opts \\ []) do {path, opts} = Keyword.pop(opts, :path) {save_opts, opts} = Keyword.split(opts, [:format, :local_npm_prefix]) diff --git a/mix.exs b/mix.exs index 67e29c6..80be283 100644 --- a/mix.exs +++ b/mix.exs @@ -91,6 +91,21 @@ defmodule EXGBoost.MixProject do groups_for_extras: [ Notebooks: Path.wildcard("notebooks/*.livemd") ], + groups_for_functions: [ + "System / Native Config": &(&1[:type] == :system), + "Training & Prediction": &(&1[:type] == :train_pred), + Serialization: &(&1[:type] == :serialization), + Plotting: &(&1[:type] == :plotting) + ], + groups_for_modules: [ + Plotting: [EXGBoost.Plotting, EXGBoost.Plotting.Styles], + Training: [ + EXGBoost.Training, + EXGBoost.Training.Callback, + EXGBoost.Booster, + EXGBoost.Parameters + ] + ], before_closing_body_tag: &before_closing_body_tag/1 ] end diff --git a/notebooks/plotting.livemd b/notebooks/plotting.livemd index 2305d36..88b9b33 100644 --- a/notebooks/plotting.livemd +++ b/notebooks/plotting.livemd @@ -1,19 +1,11 @@ # Plotting in EXGBoost ```elixir -# Mix.install( -# [ -# {:exgboost, "~> 0.5", env: :dev} -# ] -# ) - -Mix.install([ - {:exgboost, path: "/Users/andres/Documents/exgboost"}, - {:nx, "~> 0.5"}, - {:scidata, "~> 0.1"}, - {:scholar, "~> 0.1"}, - {:kino_vega_lite, "~> 0.1"} -]) +Mix.install( + [ + {:exgboost, "~> 0.5", env: :dev} + ] +) # This assumed you launch this livebook from its location in the exgboost/notebooks folder ``` @@ -32,16 +24,16 @@ This notebook will go over some of the details of the `EXGBoost.Plotting` module There are 2 main APIs exposed to control plotting in `EXGBoost`: -* Top-level API (`EXGBoost.plot_tree/2`) +- Top-level API (`EXGBoost.plot_tree/2`) - * Using predefined styles - * Defining custom styles - * Mix of the first 2 + - Using predefined styles + - Defining custom styles + - Mix of the first 2 -* `EXBoost.Plotting` module API +- `EXBoost.Plotting` module API - * Use the Vega `data` spec defined in `EXGBoost.get_data_spec/2` - * Define your own Vega spec using the data from either `EXGBoost.Plotting.to_tabular/1` or some other means + - Use the Vega `data` spec defined in `EXGBoost.Plotting.get_data_spec/2` + - Define your own Vega spec using the data from either `EXGBoost.Plotting.to_tabular/1` or some other means We will walk through each of these in detail. @@ -49,7 +41,7 @@ Regardless of which API you choose to use, it is helpful to understand how the p ## Implementation Details -The plotting functionality provided in `EXGBoost` is powered by the [Vega](https://vega.github.io/vega/) JavaScript library and the Elixir [`VegaLite`](https://hexdocs.pm/vega_lite/VegaLite.html) library which provides the piping to interop with the JavaScript libraries. **We do not actually much use the Elixir API provided by the Elixir VegaLite library. It is mainly used for the purposes of rendering.** +The plotting functionality provided in `EXGBoost` is powered by the [Vega](https://vega.github.io/vega/) JavaScript library and the Elixir [`VegaLite`](https://hexdocs.pm/vega_lite/VegaLite.html) library which provides the piping to interop with the JavaScript libraries. **We do not actually much use the Elixir API provided by the Elixir VegaLite library. It is mainly used for the purposes of rendering.** Vega is a plotting library built on top of the very powerful [D3](https://d3js.org/) JavaScript library. Vega visualizations are defined according to the respective JSON Schema specification. Vega-Lite offers a [reduced schema](https://vega.github.io/schema/vega-lite/v5.json) compared to the [full Vega spec](https://vega.github.io/schema/vega/v5.json). `EXGBoost.Plotting` leverages several transforms which are not available in the reduced Vega-Lite schema, which is the reason for targeting the lower-level API. @@ -77,7 +69,7 @@ y_test = Nx.tensor(y_test) ## Train Your Booster -Now go ahead and train your booster. We will use `early_stopping_rounds: 1` because we're not interested in the accuracy of the booster for this demonstration (*Note that we need to set `evals` to use early stopping*). +Now go ahead and train your booster. We will use `early_stopping_rounds: 1` because we're not interested in the accuracy of the booster for this demonstration (_Note that we need to set `evals` to use early stopping_). You will notice that `EXGBoost` also provides an implementation for `Kino.Render` so that `EXGBoost.Booster`s are rendered as a plot by default. @@ -120,20 +112,20 @@ This API uses [Vega `Mark`s](https://vega.github.io/vega/docs/marks/) to describ The plot is composed of the following parts: -* Top-level keys: Options controlling parts of the plot outside of direct control of a `Mark`, such as `:padding`, `:autosize`, etc. Accepts any Vega top-level [top-level key](https://vega.github.io/vega/docs/specification/) in addition to several specific to this API (scuh as `:style` and `:depth`). -* `:leaves`: `Mark` specifying the leaf nodes of the tree - * `:text`: [Text Mark](https://vega.github.io/vega/docs/marks/text/) - * `:rect`: [Rect Mark](https://vega.github.io/vega/docs/marks/rect/) -* `:splits` `Mark` specifying the split (or inner / decision) nodes of the tree - * `:text`: [Text Mark](https://vega.github.io/vega/docs/marks/text/) - * `:rect`: [Rect Mark](https://vega.github.io/vega/docs/marks/rect/) - * `:children`: [Text Mark](https://vega.github.io/vega/docs/marks/text/) for the child count -* `:yes` - * `:path`: [Path Mark](https://vega.github.io/vega/docs/marks/path/) - * `:text`: [Text Mark](https://vega.github.io/vega/docs/marks/text/) -* `:no` - * `:path`: [Path Mark](https://vega.github.io/vega/docs/marks/path/) - * `:text`: [Text Mark](https://vega.github.io/vega/docs/marks/text/) +- Top-level keys: Options controlling parts of the plot outside of direct control of a `Mark`, such as `:padding`, `:autosize`, etc. Accepts any Vega top-level [top-level key](https://vega.github.io/vega/docs/specification/) in addition to several specific to this API (scuh as `:style` and `:depth`). +- `:leaves`: `Mark` specifying the leaf nodes of the tree + - `:text`: [Text Mark](https://vega.github.io/vega/docs/marks/text/) + - `:rect`: [Rect Mark](https://vega.github.io/vega/docs/marks/rect/) +- `:splits` `Mark` specifying the split (or inner / decision) nodes of the tree + - `:text`: [Text Mark](https://vega.github.io/vega/docs/marks/text/) + - `:rect`: [Rect Mark](https://vega.github.io/vega/docs/marks/rect/) + - `:children`: [Text Mark](https://vega.github.io/vega/docs/marks/text/) for the child count +- `:yes` + - `:path`: [Path Mark](https://vega.github.io/vega/docs/marks/path/) + - `:text`: [Text Mark](https://vega.github.io/vega/docs/marks/text/) +- `:no` + - `:path`: [Path Mark](https://vega.github.io/vega/docs/marks/path/) + - `:text`: [Text Mark](https://vega.github.io/vega/docs/marks/text/) `EXGBoost.plot_tree/2` defaults to outputting a `VegaLite` struct. If you pass the `:path` option it will save to a file instead. @@ -348,7 +340,7 @@ For example, if you just want to change the default pre-configured style you can Mix.install([ {:exgboost, path: Path.join(__DIR__, ".."), env: :dev}, ], - config: + config: [ exgboost: [ plotting: [ @@ -363,7 +355,7 @@ You can also make one-off changes to any of the settings with this method. In ef ```elixir - default_style = + default_style = [ style: nil, background: "#3f3f3f", @@ -398,7 +390,7 @@ You can also make one-off changes to any of the settings with this method. In ef Mix.install([ {:exgboost, path: Path.join(__DIR__, ".."), env: :dev}, ], -config: +config: [ exgboost: [ plotting: default_style, @@ -407,7 +399,7 @@ config: ) ``` -**NOTE: When you specify a parameter in the configuration, it is merged with the defaults which is different from runtime behavior.** +**NOTE: When you specify a parameter in the configuration, it is merged with the defaults which is different from runtime behavior.** At any point, you can check what your default settings are by using `EXGBoost.Plotting.get_defaults/0` @@ -445,8 +437,8 @@ EXGBoost.Plotting.get_data_spec(booster, rankdir: :bt) The Vega fields which are not included with `get_data_spec/2` and are included in `plot/2` are: -* [Marks](https://vega.github.io/vega/docs/marks/) -* [Scales](https://vega.github.io/vega/docs/scales/) -* [Signals](https://vega.github.io/vega/docs/signals/) +- [Marks](https://vega.github.io/vega/docs/marks/) +- [Scales](https://vega.github.io/vega/docs/scales/) +- [Signals](https://vega.github.io/vega/docs/signals/) You can make a completely valid plot using only the Data from `get_data_specs/2` and adding the marks you need. From 7a67f967a1ed5eb3be6721d946111cec2f82ab7a Mon Sep 17 00:00:00 2001 From: acalejos Date: Wed, 24 Jan 2024 23:23:56 -0500 Subject: [PATCH 18/21] update runner name --- .github/workflows/precompile.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/precompile.yml b/.github/workflows/precompile.yml index 667aa21..4b6a80e 100644 --- a/.github/workflows/precompile.yml +++ b/.github/workflows/precompile.yml @@ -52,7 +52,7 @@ jobs: MIX_ENV: "prod" strategy: matrix: - runner: ["macos-latest", "self-hosted"] + runner: ["macos-latest", exgboost-m2-runner] otp: ["25.0", "26.0"] elixir: ["1.14.5"] steps: From ac1174e6fdb32d0f39b50e5c06129663bd1f5428 Mon Sep 17 00:00:00 2001 From: acalejos Date: Thu, 25 Jan 2024 21:44:47 -0500 Subject: [PATCH 19/21] update actions workflow --- .github/workflows/precompile.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/precompile.yml b/.github/workflows/precompile.yml index 4b6a80e..72b394e 100644 --- a/.github/workflows/precompile.yml +++ b/.github/workflows/precompile.yml @@ -52,7 +52,7 @@ jobs: MIX_ENV: "prod" strategy: matrix: - runner: ["macos-latest", exgboost-m2-runner] + runner: ["macos-latest", "exgboost-m2-runner"] otp: ["25.0", "26.0"] elixir: ["1.14.5"] steps: From 04a796194dcf2f98dcc8080e007828ead356513a Mon Sep 17 00:00:00 2001 From: acalejos Date: Thu, 25 Jan 2024 21:47:38 -0500 Subject: [PATCH 20/21] update nb --- lib/exgboost/plotting/styles.ex | 7 +++ notebooks/plotting.livemd | 80 ++++++++++++++++++--------------- 2 files changed, 51 insertions(+), 36 deletions(-) diff --git a/lib/exgboost/plotting/styles.ex b/lib/exgboost/plotting/styles.ex index 3e9ede3..2c7a8bd 100644 --- a/lib/exgboost/plotting/styles.ex +++ b/lib/exgboost/plotting/styles.ex @@ -15,4 +15,11 @@ defmodule EXGBoost.Plotting.Styles do """ end) |> Enum.join("\n\n")} """ + + @functions [{:a, "a"}, {:b, "b"}, {:c, "c"}] + for {name, ret} <- @functions do + def unquote(name)() do + unquote(ret) + end + end end diff --git a/notebooks/plotting.livemd b/notebooks/plotting.livemd index 88b9b33..fb4c76c 100644 --- a/notebooks/plotting.livemd +++ b/notebooks/plotting.livemd @@ -1,11 +1,19 @@ # Plotting in EXGBoost ```elixir -Mix.install( - [ - {:exgboost, "~> 0.5", env: :dev} - ] -) +# Mix.install( +# [ +# {:exgboost, "~> 0.5", env: :dev} +# ] +# ) + +Mix.install([ + {:exgboost, path: "/Users/andres/Documents/exgboost"}, + {:nx, "~> 0.5"}, + {:scidata, "~> 0.1"}, + {:scholar, "~> 0.1"}, + {:kino_vega_lite, "~> 0.1"} +]) # This assumed you launch this livebook from its location in the exgboost/notebooks folder ``` @@ -24,16 +32,16 @@ This notebook will go over some of the details of the `EXGBoost.Plotting` module There are 2 main APIs exposed to control plotting in `EXGBoost`: -- Top-level API (`EXGBoost.plot_tree/2`) +* Top-level API (`EXGBoost.plot_tree/2`) - - Using predefined styles - - Defining custom styles - - Mix of the first 2 + * Using predefined styles + * Defining custom styles + * Mix of the first 2 -- `EXBoost.Plotting` module API +* `EXBoost.Plotting` module API - - Use the Vega `data` spec defined in `EXGBoost.Plotting.get_data_spec/2` - - Define your own Vega spec using the data from either `EXGBoost.Plotting.to_tabular/1` or some other means + * Use the Vega `data` spec defined in `EXGBoost.get_data_spec/2` + * Define your own Vega spec using the data from either `EXGBoost.Plotting.to_tabular/1` or some other means We will walk through each of these in detail. @@ -41,7 +49,7 @@ Regardless of which API you choose to use, it is helpful to understand how the p ## Implementation Details -The plotting functionality provided in `EXGBoost` is powered by the [Vega](https://vega.github.io/vega/) JavaScript library and the Elixir [`VegaLite`](https://hexdocs.pm/vega_lite/VegaLite.html) library which provides the piping to interop with the JavaScript libraries. **We do not actually much use the Elixir API provided by the Elixir VegaLite library. It is mainly used for the purposes of rendering.** +The plotting functionality provided in `EXGBoost` is powered by the [Vega](https://vega.github.io/vega/) JavaScript library and the Elixir [`VegaLite`](https://hexdocs.pm/vega_lite/VegaLite.html) library which provides the piping to interop with the JavaScript libraries. **We do not actually much use the Elixir API provided by the Elixir VegaLite library. It is mainly used for the purposes of rendering.** Vega is a plotting library built on top of the very powerful [D3](https://d3js.org/) JavaScript library. Vega visualizations are defined according to the respective JSON Schema specification. Vega-Lite offers a [reduced schema](https://vega.github.io/schema/vega-lite/v5.json) compared to the [full Vega spec](https://vega.github.io/schema/vega/v5.json). `EXGBoost.Plotting` leverages several transforms which are not available in the reduced Vega-Lite schema, which is the reason for targeting the lower-level API. @@ -69,7 +77,7 @@ y_test = Nx.tensor(y_test) ## Train Your Booster -Now go ahead and train your booster. We will use `early_stopping_rounds: 1` because we're not interested in the accuracy of the booster for this demonstration (_Note that we need to set `evals` to use early stopping_). +Now go ahead and train your booster. We will use `early_stopping_rounds: 1` because we're not interested in the accuracy of the booster for this demonstration (*Note that we need to set `evals` to use early stopping*). You will notice that `EXGBoost` also provides an implementation for `Kino.Render` so that `EXGBoost.Booster`s are rendered as a plot by default. @@ -112,20 +120,20 @@ This API uses [Vega `Mark`s](https://vega.github.io/vega/docs/marks/) to describ The plot is composed of the following parts: -- Top-level keys: Options controlling parts of the plot outside of direct control of a `Mark`, such as `:padding`, `:autosize`, etc. Accepts any Vega top-level [top-level key](https://vega.github.io/vega/docs/specification/) in addition to several specific to this API (scuh as `:style` and `:depth`). -- `:leaves`: `Mark` specifying the leaf nodes of the tree - - `:text`: [Text Mark](https://vega.github.io/vega/docs/marks/text/) - - `:rect`: [Rect Mark](https://vega.github.io/vega/docs/marks/rect/) -- `:splits` `Mark` specifying the split (or inner / decision) nodes of the tree - - `:text`: [Text Mark](https://vega.github.io/vega/docs/marks/text/) - - `:rect`: [Rect Mark](https://vega.github.io/vega/docs/marks/rect/) - - `:children`: [Text Mark](https://vega.github.io/vega/docs/marks/text/) for the child count -- `:yes` - - `:path`: [Path Mark](https://vega.github.io/vega/docs/marks/path/) - - `:text`: [Text Mark](https://vega.github.io/vega/docs/marks/text/) -- `:no` - - `:path`: [Path Mark](https://vega.github.io/vega/docs/marks/path/) - - `:text`: [Text Mark](https://vega.github.io/vega/docs/marks/text/) +* Top-level keys: Options controlling parts of the plot outside of direct control of a `Mark`, such as `:padding`, `:autosize`, etc. Accepts any Vega top-level [top-level key](https://vega.github.io/vega/docs/specification/) in addition to several specific to this API (scuh as `:style` and `:depth`). +* `:leaves`: `Mark` specifying the leaf nodes of the tree + * `:text`: [Text Mark](https://vega.github.io/vega/docs/marks/text/) + * `:rect`: [Rect Mark](https://vega.github.io/vega/docs/marks/rect/) +* `:splits` `Mark` specifying the split (or inner / decision) nodes of the tree + * `:text`: [Text Mark](https://vega.github.io/vega/docs/marks/text/) + * `:rect`: [Rect Mark](https://vega.github.io/vega/docs/marks/rect/) + * `:children`: [Text Mark](https://vega.github.io/vega/docs/marks/text/) for the child count +* `:yes` + * `:path`: [Path Mark](https://vega.github.io/vega/docs/marks/path/) + * `:text`: [Text Mark](https://vega.github.io/vega/docs/marks/text/) +* `:no` + * `:path`: [Path Mark](https://vega.github.io/vega/docs/marks/path/) + * `:text`: [Text Mark](https://vega.github.io/vega/docs/marks/text/) `EXGBoost.plot_tree/2` defaults to outputting a `VegaLite` struct. If you pass the `:path` option it will save to a file instead. @@ -146,7 +154,7 @@ EXGBoost.plot_tree(booster, rankdir: :bt) By default, plotting only shows one (the first) tree, but seeing as a `Booster` is really an ensemble of trees you can choose which tree to plot through the `:index` option, or set to `nil` to have a dropdown box to select the tree. ```elixir -EXGBoost.plot_tree(booster, rankdir: :lr, index: 4) +EXGBoost.plot_tree(booster, rankdir: :lr, index: nil) ``` You'll also notice that the plot is interactive, with support for scrolling, zooming, and collapsing sections of the tree. If you click on a split node you will toggle the visibility of its descendents, and the rest of the tree will fill the canvas. @@ -340,7 +348,7 @@ For example, if you just want to change the default pre-configured style you can Mix.install([ {:exgboost, path: Path.join(__DIR__, ".."), env: :dev}, ], - config: + config: [ exgboost: [ plotting: [ @@ -355,7 +363,7 @@ You can also make one-off changes to any of the settings with this method. In ef ```elixir - default_style = + default_style = [ style: nil, background: "#3f3f3f", @@ -390,7 +398,7 @@ You can also make one-off changes to any of the settings with this method. In ef Mix.install([ {:exgboost, path: Path.join(__DIR__, ".."), env: :dev}, ], -config: +config: [ exgboost: [ plotting: default_style, @@ -399,7 +407,7 @@ config: ) ``` -**NOTE: When you specify a parameter in the configuration, it is merged with the defaults which is different from runtime behavior.** +**NOTE: When you specify a parameter in the configuration, it is merged with the defaults which is different from runtime behavior.** At any point, you can check what your default settings are by using `EXGBoost.Plotting.get_defaults/0` @@ -437,8 +445,8 @@ EXGBoost.Plotting.get_data_spec(booster, rankdir: :bt) The Vega fields which are not included with `get_data_spec/2` and are included in `plot/2` are: -- [Marks](https://vega.github.io/vega/docs/marks/) -- [Scales](https://vega.github.io/vega/docs/scales/) -- [Signals](https://vega.github.io/vega/docs/signals/) +* [Marks](https://vega.github.io/vega/docs/marks/) +* [Scales](https://vega.github.io/vega/docs/scales/) +* [Signals](https://vega.github.io/vega/docs/signals/) You can make a completely valid plot using only the Data from `get_data_specs/2` and adding the marks you need. From cf10353ad76878d51b23197773eaffca5b7bed77 Mon Sep 17 00:00:00 2001 From: acalejos Date: Fri, 26 Jan 2024 22:02:08 -0500 Subject: [PATCH 21/21] allow manual actions --- .github/workflows/precompile.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/precompile.yml b/.github/workflows/precompile.yml index 72b394e..ea027e0 100644 --- a/.github/workflows/precompile.yml +++ b/.github/workflows/precompile.yml @@ -1,6 +1,8 @@ name: precompile -on: push +on: + - push + - workflow_dispatch jobs: linux: