Skip to content

Commit

Permalink
WIP plotting notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
acalejos committed Jan 24, 2024
1 parent ead27a3 commit d80378c
Show file tree
Hide file tree
Showing 3 changed files with 247 additions and 47 deletions.
65 changes: 31 additions & 34 deletions lib/exgboost/plotting.ex
Original file line number Diff line number Diff line change
Expand Up @@ -413,69 +413,58 @@ defmodule EXGBoost.Plotting do
align: :center,
baseline: :middle,
font_size: 13,
font: "Calibri",
fill: :black
font: "Calibri"
]

@default_leaf_rect [
corner_radius: 2,
fill: :teal,
opacity: 1
]

@default_split_text [
align: :center,
baseline: :middle,
font_size: 13,
font: "Calibri",
fill: :black
font: "Calibri"
]

@default_split_rect [
corner_radius: 2,
fill: :teal,
opacity: 1
]

@default_split_children [
align: :right,
baseline: :middle,
fill: :black,
font: "Calibri",
font_size: 13
]

@default_yes_path [
stroke: :red
]
@default_yes_path []

@default_yes_text [
align: :center,
baseline: :middle,
font_size: 13,
font: "Calibri",
fill: :black,
text: "yes"
]

@default_no_path [
stroke: :black
]
@default_no_path []

@default_no_text [
align: :center,
baseline: :middle,
font_size: 13,
font: "Calibri",
fill: :black,
text: "no"
]

@plotting_params [
style: [
doc:
"The style to use for the visualization. Refer to `EXGBoost.Plotting.Styles` for a list of available styles.",
default: Application.compile_env(:exgboost, [:plotting, :style], :solarized_light),
default: Application.compile_env(:exgboost, [:plotting, :style], :dracula),
type: {:or, [{:in, Keyword.keys(@styles)}, {:in, [nil, false]}, :keyword_list]}
],
rankdir: [
Expand Down Expand Up @@ -948,12 +937,13 @@ defmodule EXGBoost.Plotting do
opts

Keyword.keyword?(opts[:style]) ->
deep_merge_kw(opts, opts[:style])
deep_merge_kw(opts[:style], opts, @defaults)

true ->
style = apply(__MODULE__, opts[:style], [])
deep_merge_kw(opts, style)
deep_merge_kw(style, opts, @defaults)
end
|> then(&deep_merge_kw(@defaults, &1))

%{
"$schema" => "https://vega.github.io/schema/vega/v5.json",
Expand Down Expand Up @@ -1278,32 +1268,31 @@ defmodule EXGBoost.Plotting do
opts

Keyword.keyword?(opts[:style]) ->
deep_merge_kw(opts, opts[:style])
deep_merge_kw(opts[:style], opts, @defaults)

true ->
style = apply(__MODULE__, opts[:style], [])
deep_merge_kw(opts, style)
deep_merge_kw(style, opts, @defaults)
end

defaults = get_defaults()
|> then(&deep_merge_kw(@defaults, &1))

# Try to account for non-default node height / width to adjust spacing
# between nodes and levels as a quality of life improvement
# If they provide their own spacing, don't adjust it
opts =
cond do
opts[:rankdir] in [:lr, :rl] and
opts[:space_between][:levels] == defaults[:space_between][:levels] ->
opts[:space_between][:levels] == @defaults[:space_between][:levels] ->
put_in(
opts[:space_between][:levels],
defaults[:space_between][:levels] + (opts[:node_width] - defaults[:node_width])
@defaults[:space_between][:levels] + (opts[:node_width] - @defaults[:node_width])
)

opts[:rankdir] in [:tb, :bt] and
opts[:space_between][:levels] == defaults[:space_between][:levels] ->
opts[:space_between][:levels] == @defaults[:space_between][:levels] ->
put_in(
opts[:space_between][:levels],
defaults[:space_between][:levels] + (opts[:node_height] - defaults[:node_height])
@defaults[:space_between][:levels] + (opts[:node_height] - @defaults[:node_height])
)

true ->
Expand All @@ -1313,17 +1302,17 @@ defmodule EXGBoost.Plotting do
opts =
cond do
opts[:rankdir] in [:lr, :rl] and
opts[:space_between][:nodes] == defaults[:space_between][:nodes] ->
opts[:space_between][:nodes] == @defaults[:space_between][:nodes] ->
put_in(
opts[:space_between][:nodes],
defaults[:space_between][:nodes] + (opts[:node_height] - defaults[:node_height])
@defaults[:space_between][:nodes] + (opts[:node_height] - @defaults[:node_height])
)

opts[:rankdir] in [:tb, :bt] and
opts[:space_between][:nodes] == defaults[:space_between][:nodes] ->
opts[:space_between][:nodes] == @defaults[:space_between][:nodes] ->
put_in(
opts[:space_between][:nodes],
defaults[:space_between][:nodes] + (opts[:node_width] - defaults[:node_width])
@defaults[:space_between][:nodes] + (opts[:node_width] - @defaults[:node_width])
)

true ->
Expand Down Expand Up @@ -1381,15 +1370,19 @@ defmodule EXGBoost.Plotting do
%{
"encode" => %{
"update" =>
Map.merge(
deep_merge_maps(
%{
"x" => %{
"signal" =>
"(scale('xscale', datum.source.x+(nodeWidth/3)) + scale('xscale', datum.target.x)) / 2"
"(scale('xscale', datum.source.x#{cond do
opts[:rankdir] in [:tb, :bt, :lr] -> ~c"-nodeWidth/4"
opts[:rankdir] in [:rl] -> ~c"+nodeWidth/4"
true -> ~c""
end}) + scale('xscale', datum.target.x)) / 2"
},
"y" => %{
"signal" =>
"(scale('yscale', datum.source.y) + scale('yscale', datum.target.y)) / 2 - (scaledNodeHeight/2)"
"(scale('yscale', datum.source.y#{if opts[:rankdir] in [:lr, :rl], do: ~c"-nodeWidth/3", else: ~c""}) + scale('yscale', datum.target.y)) / 2 - (scaledNodeHeight/2)"
}
},
format_mark(opts[:yes][:text])
Expand All @@ -1405,7 +1398,11 @@ defmodule EXGBoost.Plotting do
%{
"x" => %{
"signal" =>
"(scale('xscale', datum.source.x+(nodeWidth/3)) + scale('xscale', datum.target.x)) / 2"
"(scale('xscale', datum.source.x#{cond do
opts[:rankdir] in [:tb, :bt, :lr] -> ~c"-nodeWidth/4"
opts[:rankdir] in [:rl] -> ~c"+nodeWidth/4"
true -> ~c""
end}) + scale('xscale', datum.target.x)) / 2"
},
"y" => %{
"signal" =>
Expand Down
20 changes: 19 additions & 1 deletion lib/exgboost/plotting/style.ex
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,31 @@ defmodule EXGBoost.Plotting.Style do
end
end

def deep_merge_kw(a, b) do
def deep_merge_kw(a, b, ignore_set \\ []) do
Keyword.merge(a, b, fn
_key, val_a, val_b when is_list(val_a) and is_list(val_b) ->
deep_merge_kw(val_a, val_b)

key, val_a, val_b ->
if Keyword.has_key?(b, key) do
if Keyword.has_key?(ignore_set, key) and Keyword.get(ignore_set, key) == val_b do
val_a
else
val_b
end
else
val_a
end
end)
end

def deep_merge_maps(b, a) do
Map.merge(a, b, fn
_key, val_a, val_b when is_map(val_a) and is_map(val_b) ->
deep_merge_maps(val_a, val_b)

key, val_a, val_b ->
if Map.has_key?(b, key) do
val_b
else
val_a
Expand Down
Loading

0 comments on commit d80378c

Please sign in to comment.