Skip to content

Commit

Permalink
externalize train with dmatrix
Browse files Browse the repository at this point in the history
  • Loading branch information
behrica committed Oct 9, 2024
1 parent c7c861b commit adee610
Showing 1 changed file with 40 additions and 33 deletions.
73 changes: 40 additions & 33 deletions src/scicloj/ml/xgboost.clj
Original file line number Diff line number Diff line change
Expand Up @@ -251,17 +251,26 @@ subsample may be set to as low as 0.1 without loss of model accuracy. Note that
(sparse-feature->dmatrix feature-ds target-ds sparse-column n-sparse-columns)
(dataset->dmatrix feature-ds target-ds)))

(defn- train
[feature-ds label-ds options]


(defn- thaw-model
  "Rebuilds an XGBoost model from its serialized form.

  `model-data` is either the serialized model itself or a map that holds it
  under `:model-data`. The payload is wrapped in a `ByteArrayInputStream`
  (so it must be a byte array) and handed to `XGBoost/loadModel`, whose
  result is returned."
  [model-data]
  (let [serialized (if (map? model-data)
                     (:model-data model-data)
                     model-data)
        in-stream  (ByteArrayInputStream. serialized)]
    (XGBoost/loadModel in-stream)))



(defn train-from-dmatrix
[train-dmat feature-cnames target-cnames options label-map objective]
;;XGBoost uses all cores so serialization here avoids over subscribing
;;the machine.
(locking #'multiclass-objective?
(let [objective (options->objective options)
(let [
sparse-column-or-nil (:sparse-column options)
train-dmat (->dmatrix feature-ds label-ds sparse-column-or-nil (:n-sparse-columns options))
base-watches (or (:watches options) {})
feature-cnames (ds/column-names feature-ds)
target-cnames (ds/column-names label-ds)
watches (->> base-watches
(reduce (fn [^Map watches [k v]]
(.put watches (ds-utils/column-safe-name k)
Expand All @@ -271,7 +280,7 @@ subsample may be set to as low as 0.1 without loss of model accuracy. Note that
sparse-column-or-nil
(:n-sparse-columns options)))
watches)
;;Linked hash map to preserve order
;;Linked hash map to preserve order
(LinkedHashMap.)))
round (or (:round options) 25)
early-stopping-round (or (when (:early-stopping-round options)
Expand All @@ -281,46 +290,41 @@ subsample may be set to as low as 0.1 without loss of model accuracy. Note that
(not (instance? LinkedHashMap (:watches options)))
(not= 0 early-stopping-round))
(log/warn "Early stopping indicated but watches has undefined iteration order.
Early stopping will always use the 'last' of the watches as defined by the iteration
order of the watches map. Consider using a java.util.LinkedHashMap for watches.
https://github.com/dmlc/xgboost/blob/master/jvm-packages/xgboost4j/src/main/java/ml/dml
c/xgboost4j/java/XGBoost.java#L208"))
Early stopping will always use the 'last' of the watches as defined by the iteration
order of the watches map. Consider using a java.util.LinkedHashMap for watches.
https://github.com/dmlc/xgboost/blob/master/jvm-packages/xgboost4j/src/main/java/ml/dml
c/xgboost4j/java/XGBoost.java#L208"))
watch-names (->> base-watches
(map-indexed (fn [idx [k v]]
[idx k]))
(into {}))
label-map (when (multiclass-objective? objective)
(ds-mod/inference-target-label-map label-ds))
cleaned-options
cleaned-options
(->
(dissoc options :model-type :watches)
(assoc :objective objective))
params (->> cleaned-options
;;Adding in some defaults
;;Adding in some defaults
(merge
{
:alpha 0.0
:eta 0.3
:lambda 1.0
:max-depth 6
:subsample 0.87

}

{:alpha 0.0
:eta 0.3
:lambda 1.0
:max-depth 6
:subsample 0.87}

cleaned-options
(when label-map
{:num-class (count label-map)}))
(map (fn [[k v]]
(when v
[(s/replace (name k) "-" "_" ) v])))
[(s/replace (name k) "-" "_") v])))

(remove nil?)
(into {}))
^"[[F" metrics-data (when-not (empty? watches)
(->> (repeatedly (count watches)
#(float-array round))
(into-array)))

_ (println :params params)
^Booster model (XGBoost/train train-dmat params
(long round)
(or watches {}) metrics-data nil nil
Expand All @@ -340,13 +344,16 @@ c/xgboost4j/java/XGBoost.java#L208"))
(ds/->>dataset {:dataset-name :metrics}))})))))


(defn- thaw-model
  "Rebuilds an XGBoost model from its serialized form.

  `model-data` is either the serialized model itself or a map that holds it
  under `:model-data`. The payload is wrapped in a `ByteArrayInputStream`
  (so it must be a byte array) and handed to `XGBoost/loadModel`, whose
  result is returned."
  [model-data]
  (let [serialized (if (map? model-data)
                     (:model-data model-data)
                     model-data)
        in-stream  (ByteArrayInputStream. serialized)]
    (XGBoost/loadModel in-stream)))
(defn train
  "Trains a model from a feature dataset and a label dataset.

  Collects the column names of both datasets, builds the training DMatrix
  with `->dmatrix` (passing along the `:sparse-column` and
  `:n-sparse-columns` entries of `options`), resolves the objective with
  `options->objective`, and computes a label map via
  `ds-mod/inference-target-label-map` only when the objective is a
  multiclass one (nil otherwise). Everything is then delegated to
  `train-from-dmatrix`, whose result is returned."
  [feature-ds label-ds options]
  (let [feature-names (ds/column-names feature-ds)
        target-names  (ds/column-names label-ds)
        dmat          (->dmatrix feature-ds
                                 label-ds
                                 (:sparse-column options)
                                 (:n-sparse-columns options))
        objective     (options->objective options)
        multiclass?   (multiclass-objective? objective)
        ;; Only multiclass objectives need the label map.
        label-map     (if multiclass?
                        (ds-mod/inference-target-label-map label-ds)
                        nil)]
    (train-from-dmatrix dmat feature-names target-names options label-map objective)))


(defn- predict
Expand Down

0 comments on commit adee610

Please sign in to comment.