From 4c5672006172475266b665ea7c2a9e9f12018b1e Mon Sep 17 00:00:00 2001 From: acalejos Date: Fri, 19 Jan 2024 23:35:58 -0500 Subject: [PATCH] set feature names in dmatrices --- lib/exgboost/training.ex | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/lib/exgboost/training.ex b/lib/exgboost/training.ex index 8c27b44..3c58715 100644 --- a/lib/exgboost/training.ex +++ b/lib/exgboost/training.ex @@ -6,6 +6,8 @@ defmodule EXGBoost.Training do @spec train(DMatrix.t(), Keyword.t()) :: Booster.t() def train(%DMatrix{} = dmat, opts \\ []) do + dmat_opts = Keyword.take(opts, EXGBoost.Internal.dmatrix_feature_opts()) + valid_opts = [ callbacks: [], early_stopping_rounds: nil, @@ -47,7 +49,7 @@ defmodule EXGBoost.Training do evals_dmats = Enum.map(evals, fn {x, y, name} -> - {DMatrix.from_tensor(x, y, format: :dense), name} + {DMatrix.from_tensor(x, y, Keyword.put_new(dmat_opts, :format, :dense)), name} end) bst = @@ -174,7 +176,7 @@ defmodule EXGBoost.Training do # XGBoost in the Makefile to dump the config to JSON. # %{"learner" => %{"metrics" => metrics, "default_metric" => default_metric}} = - EXGBoost.dump_config(bst) |> Jason.decode!() |> dbg() + EXGBoost.dump_config(bst) |> Jason.decode!() metric_name = cond do