Skip to content

Commit

Permalink
set feature names in dmatrices
Browse files Browse the repository at this point in the history
  • Loading branch information
acalejos committed Jan 20, 2024
1 parent f3c5aae commit 4c56720
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions lib/exgboost/training.ex
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ defmodule EXGBoost.Training do

@spec train(DMatrix.t(), Keyword.t()) :: Booster.t()
def train(%DMatrix{} = dmat, opts \\ []) do
dmat_opts = Keyword.take(opts, EXGBoost.Internal.dmatrix_feature_opts())

valid_opts = [
callbacks: [],
early_stopping_rounds: nil,
Expand Down Expand Up @@ -47,7 +49,7 @@ defmodule EXGBoost.Training do

evals_dmats =
Enum.map(evals, fn {x, y, name} ->
{DMatrix.from_tensor(x, y, format: :dense), name}
{DMatrix.from_tensor(x, y, Keyword.put_new(dmat_opts, :format, :dense)), name}
end)

bst =
Expand Down Expand Up @@ -174,7 +176,7 @@ defmodule EXGBoost.Training do
# XGBoost in the Makefile to dump the config to JSON.
#
%{"learner" => %{"metrics" => metrics, "default_metric" => default_metric}} =
EXGBoost.dump_config(bst) |> Jason.decode!() |> dbg()
EXGBoost.dump_config(bst) |> Jason.decode!()

metric_name =
cond do
Expand Down

0 comments on commit 4c56720

Please sign in to comment.