diff --git a/lib/exgboost/training.ex b/lib/exgboost/training.ex index 8c27b44..3c58715 100644 --- a/lib/exgboost/training.ex +++ b/lib/exgboost/training.ex @@ -6,6 +6,8 @@ defmodule EXGBoost.Training do @spec train(DMatrix.t(), Keyword.t()) :: Booster.t() def train(%DMatrix{} = dmat, opts \\ []) do + dmat_opts = Keyword.take(opts, EXGBoost.Internal.dmatrix_feature_opts()) + valid_opts = [ callbacks: [], early_stopping_rounds: nil, @@ -47,7 +49,7 @@ defmodule EXGBoost.Training do evals_dmats = Enum.map(evals, fn {x, y, name} -> - {DMatrix.from_tensor(x, y, format: :dense), name} + {DMatrix.from_tensor(x, y, Keyword.put_new(dmat_opts, :format, :dense)), name} end) bst = @@ -174,7 +176,7 @@ defmodule EXGBoost.Training do # XGBoost in the Makefile to dump the config to JSON. # %{"learner" => %{"metrics" => metrics, "default_metric" => default_metric}} = - EXGBoost.dump_config(bst) |> Jason.decode!() |> dbg() + EXGBoost.dump_config(bst) |> Jason.decode!() metric_name = cond do