Skip to content

Commit

Permalink
Merge pull request #120 from metarank/feature/catboost-options
Browse files Browse the repository at this point in the history
more options to the catboost booster
  • Loading branch information
shuttie authored Feb 9, 2023
2 parents 26b2640 + 38ed15a commit f61ebd1
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,15 @@ object CatboostBooster extends BoosterFactory[String, CatboostBooster, CatboostO
val modelFile = dir.createChild("model.bin")
val opts = Map(
"--learn-set" -> dataset,
"--loss-function" -> "QueryRMSE",
"--loss-function" -> options.objective,
"--eval-metric" -> s"NDCG:top=${options.ndcgCutoff}",
"--iterations" -> options.trees.toString,
"--depth" -> options.maxDepth.toString,
"--learning-rate" -> options.learningRate.toString,
"--train-dir" -> dir.toString(),
"--model-file" -> modelFile.toString()
"--model-file" -> modelFile.toString(),
"--logging-level" -> "Silent",
"--random-seed" -> options.randomSeed.toString
) ++ test.map(t => Map("--test-set" -> t)).getOrElse(Map.empty)
native_impl.ModeFitImpl(new TVector_TString(opts.flatMap(kv => List(kv._1, kv._2)).toArray))
val bytes = modelFile.byteArray
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,7 @@ case class CatboostOptions(
learningRate: Double = 0.1,
ndcgCutoff: Int = 10,
maxDepth: Int = 8,
randomSeed: Int = Random.nextInt()
randomSeed: Int = math.abs(Random.nextInt()),
objective: String = "QueryRMSE",
loggingLevel: String = "Verbose"
) extends BoosterOptions
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ case class LightGBMOptions(
learningRate: Double = 0.1,
ndcgCutoff: Int = 10,
maxDepth: Int = 8,
randomSeed: Int = Random.nextInt(),
randomSeed: Int = math.abs(Random.nextInt()),
numLeaves: Int = 16,
featureFraction: Double = 1.0
) extends BoosterOptions
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,6 @@ case class XGBoostOptions(
learningRate: Double = 0.1,
ndcgCutoff: Int = 10,
maxDepth: Int = 8,
randomSeed: Int = Random.nextInt(),
randomSeed: Int = math.abs(Random.nextInt()),
subsample: Double = 1.0
) extends BoosterOptions

0 comments on commit f61ebd1

Please sign in to comment.