From 6ef352639cfcf0f2d8c9c3ff4753c6df562c4a4f Mon Sep 17 00:00:00 2001 From: Essam Date: Thu, 5 Oct 2023 02:42:15 +0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=91=20Fix=20compatibility=20issue?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- example/BalancedBagging.ipynb | 55 ++++---- example/BalancedModel.ipynb | 238 +++++++++++++++++++++++++--------- example/Manifest.toml | 48 ++++--- src/balanced_bagging.jl | 4 +- test/balanced_bagging.jl | 12 +- test/balanced_model.jl | 2 +- 6 files changed, 245 insertions(+), 114 deletions(-) diff --git a/example/BalancedBagging.ipynb b/example/BalancedBagging.ipynb index efe54f4..5000a41 100644 --- a/example/BalancedBagging.ipynb +++ b/example/BalancedBagging.ipynb @@ -40,7 +40,7 @@ { "data": { "text/plain": [ - "((Column1 = [0.9695150609084499, 0.012898301755861596, 0.7555027304121053, 0.3467415729179013, 0.35969402837473463, 0.2601876747805505, 0.9522580699968279, 0.06304475092339623, 0.18909001622655808, 0.19934942931986965 … 0.021532597906190776, 0.8482825697641306, 0.10773487816863903, 0.32189982199036116, 0.12662208474317038, 0.28529465447429614, 0.2907506630258835, 0.36872799387588473, 0.061489791166806085, 0.45645058368583713], Column2 = [0.06546916714160167, 0.7243956502957003, 0.5183099801474415, 0.7555562860508294, 0.11226218114407538, 0.9135150277876691, 0.8739421974558176, 0.2268482788660101, 0.580604436651146, 0.4142252330250549 … 0.6517425913240111, 0.01713263102740481, 0.7175499403837856, 0.7362894157420817, 0.24893665902538054, 0.41499951381631595, 0.2159527717429719, 0.8966879835264249, 0.87252430655793, 0.41461921031276117], Column3 = [0.5939320702328891, 0.19329886972497456, 0.04656947038518311, 0.22095698685781184, 0.678807659662497, 0.12720198818430306, 0.6795750371448686, 0.9314917999820301, 0.22920734893984274, 0.5148148980955375 … 0.55049773593343, 0.038576459283091946, 0.27765727942909757, 0.2753072414696357, 0.8823620780359746, 0.44831794170895023, 0.9073846432163745, 0.4648550947905655, 0.311984726769037, 0.25829997798611304], Column4 = [0.12253944650540982, 0.8259140842535423, 0.4034477332184384, 0.5279399406265695, 0.5579944087437719, 0.24650366028608328, 0.6874897000162434, 0.23391406844015605, 0.5641254897013973, 0.6250622796341656 … 0.21708181942178983, 0.35224683896541464, 0.8444113778983325, 0.4547214584884428, 0.13508852017592232, 0.9510137735662383, 0.5723463533029658, 0.626377972762265, 0.7854013810594317, 0.15394691114473347], Column5 = [0.47958743625921163, 0.45779753417165514, 0.6367059235247621, 0.8601116026079643, 0.3334020182022719, 0.41593698717526373, 0.13208968772625174, 0.16951044109747648, 0.8137887839507706, 0.4429229861115882 … 0.01308976221980429, 0.48597926808091163, 0.20768781798463476, 0.30045611276046247, 0.15759293576302558, 0.975806377881983, 0.19451065500145392, 0.9638103356367584, 0.3594043445295293, 0.7792867217495332], Column6 = [3.0, 3.0, 1.0, 3.0, 1.0, 2.0, 3.0, 2.0, 3.0, 3.0 … 3.0, 2.0, 1.0, 2.0, 1.0, 2.0, 2.0, 3.0, 3.0, 1.0], Column7 = [2.0, 2.0, 2.0, 2.0, 1.0, 2.0, 2.0, 2.0, 1.0, 1.0 … 2.0, 1.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 1.0, 1.0]), CategoricalArrays.CategoricalValue{Int64, UInt32}[0, 0, 0, 0, 0, 0, 0, 0, 1, 0 … 0, 0, 1, 0, 1, 0, 0, 0, 0, 0])" + "((Column1 = [0.564, 0.862, 0.793, 0.505, 0.683, 0.699, 0.545, 0.693, 0.95, 0.44 … 0.423, 0.632, 0.922, 0.592, 0.944, 0.517, 0.785, 0.579, 0.725, 0.711], Column2 = [0.42, 0.715, 0.358, -0.009, 0.228, 0.725, 0.786, 0.52, 0.646, 0.582 … 0.65, 0.633, 0.263, 0.141, 0.472, 0.45, -0.019, 0.593, 0.777, 0.877], Column3 = [0.638, 0.719, 0.716, 0.604, 0.616, 0.784, 0.697, 0.711, 0.878, 0.739 … 0.722, 0.672, 0.879, 0.598, 0.879, 0.669, 0.728, 0.768, 0.736, 0.725], Column4 = [0.29, 0.164, 0.164, 0.262, 0.246, 0.211, 0.155, 0.03, 1.842, 0.324 … 0.192, 0.143, 1.323, 0.251, 1.084, 0.165, 0.138, 0.176, 0.155, 0.217], Column5 = [0.605, 0.287, 0.565, 0.121, 0.752, 0.317, 0.165, 0.497, 0.361, 0.293 … 0.726, 0.781, 0.694, 0.728, 0.692, 0.351, 0.089, 0.478, 0.067, -0.19], Column6 = [2.0, 1.0, 3.0, 1.0, 3.0, 1.0, 3.0, 2.0, 2.0, 3.0 … 1.0, 3.0, 2.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0], Column7 = [2.0, 2.0, 1.0, 2.0, 1.0, 1.0, 2.0, 2.0, 1.0, 2.0 … 1.0, 2.0, 1.0, 1.0, 1.0, 2.0, 1.0, 1.0, 1.0, 1.0]), CategoricalArrays.CategoricalValue{Int64, UInt32}[0, 0, 0, 0, 0, 0, 0, 0, 1, 0 … 0, 0, 1, 0, 1, 0, 0, 0, 0, 0])" ] }, "metadata": {}, @@ -48,8 +48,8 @@ } ], "source": [ - "X, y = generate_imbalanced_data(100, 5; cat_feats_num_vals = [3, 2], \n", - " probs = [0.9, 0.1], \n", + "X, y = generate_imbalanced_data(100, 5; num_vals_per_category = [3, 2], \n", + " class_probs = [0.9, 0.1], \n", " type = \"ColTable\", \n", " rng=42)" ] @@ -73,6 +73,15 @@ "WARNING: using StaticArrays.setindex in module FiniteDiff conflicts with an existing identifier.\n" ] }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "┌ Warning: The call to compilecache failed to create a usable precompiled cache file for MLJLinearModels [6ee0df7b-362f-4a72-a706-9e79364fb692]\n", + "│ exception = ErrorException(\"Required dependency Optim [429524aa-4258-5aef-a3af-852621145aeb] failed to load from a cache file.\")\n", + "└ @ Base loading.jl:1349\n" + ] + }, { "data": { "text/plain": [ @@ -108,7 +117,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -127,26 +136,26 @@ "data": { "text/plain": [ "100-element CategoricalDistributions.UnivariateFiniteVector{Multiclass{2}, Int64, UInt32, Float64}:\n", - " UnivariateFinite{Multiclass{2}}(0=>0.928, 1=>0.0722)\n", - " UnivariateFinite{Multiclass{2}}(0=>0.845, 1=>0.155)\n", - " UnivariateFinite{Multiclass{2}}(0=>0.749, 1=>0.251)\n", - " UnivariateFinite{Multiclass{2}}(0=>0.902, 1=>0.0977)\n", - " UnivariateFinite{Multiclass{2}}(0=>0.804, 1=>0.196)\n", - " UnivariateFinite{Multiclass{2}}(0=>0.864, 1=>0.136)\n", - " UnivariateFinite{Multiclass{2}}(0=>0.851, 1=>0.149)\n", - " UnivariateFinite{Multiclass{2}}(0=>0.954, 1=>0.0458)\n", - " UnivariateFinite{Multiclass{2}}(0=>0.853, 1=>0.147)\n", - " UnivariateFinite{Multiclass{2}}(0=>0.86, 1=>0.14)\n", + " UnivariateFinite{Multiclass{2}}(0=>1.0, 1=>0.0)\n", + " UnivariateFinite{Multiclass{2}}(0=>1.0, 1=>0.0)\n", + " UnivariateFinite{Multiclass{2}}(0=>1.0, 1=>0.0)\n", + " UnivariateFinite{Multiclass{2}}(0=>1.0, 1=>0.0)\n", + " UnivariateFinite{Multiclass{2}}(0=>1.0, 1=>0.0)\n", + " UnivariateFinite{Multiclass{2}}(0=>1.0, 1=>0.0)\n", + " UnivariateFinite{Multiclass{2}}(0=>1.0, 1=>0.0)\n", + " UnivariateFinite{Multiclass{2}}(0=>1.0, 1=>0.0)\n", + " UnivariateFinite{Multiclass{2}}(0=>0.0, 1=>1.0)\n", + " UnivariateFinite{Multiclass{2}}(0=>1.0, 1=>0.0)\n", " ⋮\n", - " UnivariateFinite{Multiclass{2}}(0=>0.671, 1=>0.329)\n", - " UnivariateFinite{Multiclass{2}}(0=>0.73, 1=>0.27)\n", - " UnivariateFinite{Multiclass{2}}(0=>0.843, 1=>0.157)\n", - " UnivariateFinite{Multiclass{2}}(0=>0.941, 1=>0.0594)\n", - " UnivariateFinite{Multiclass{2}}(0=>0.872, 1=>0.128)\n", - " UnivariateFinite{Multiclass{2}}(0=>0.92, 1=>0.0797)\n", - " UnivariateFinite{Multiclass{2}}(0=>0.929, 1=>0.0714)\n", - " UnivariateFinite{Multiclass{2}}(0=>0.791, 1=>0.209)\n", - " UnivariateFinite{Multiclass{2}}(0=>0.827, 1=>0.173)" + " UnivariateFinite{Multiclass{2}}(0=>1.0, 1=>0.0)\n", + " UnivariateFinite{Multiclass{2}}(0=>0.0, 1=>1.0)\n", + " UnivariateFinite{Multiclass{2}}(0=>1.0, 1=>0.0)\n", + " UnivariateFinite{Multiclass{2}}(0=>0.0, 1=>1.0)\n", + " UnivariateFinite{Multiclass{2}}(0=>1.0, 1=>0.0)\n", + " UnivariateFinite{Multiclass{2}}(0=>1.0, 1=>0.0)\n", + " UnivariateFinite{Multiclass{2}}(0=>1.0, 1=>0.0)\n", + " UnivariateFinite{Multiclass{2}}(0=>1.0, 1=>0.0)\n", + " UnivariateFinite{Multiclass{2}}(0=>1.0, 1=>0.0)" ] }, "metadata": {}, diff --git a/example/BalancedModel.ipynb b/example/BalancedModel.ipynb index a168cf8..93eb1e8 100644 --- a/example/BalancedModel.ipynb +++ b/example/BalancedModel.ipynb @@ -4,7 +4,15 @@ "cell_type": "code", "execution_count": 1, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m\u001b[1m Activating\u001b[22m\u001b[39m project at `~/Documents/GitHub/MLJBalancing/example`\n" + ] + } + ], "source": [ "ENV[\"JULIA_PKG_SERVER\"] = \"\"\n", "using Pkg\n", @@ -34,14 +42,14 @@ "name": "stdout", "output_type": "stream", "text": [ - "0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 204 (40.9%) \n", - "1: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 297 (59.5%) \n", - "2: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 499 (100.0%) \n" + "0: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 189 (37.4%) \n", + "1: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 305 (60.3%) \n", + "2: ▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇ 506 (100.0%) \n" ] } ], "source": [ - "X, y = Imbalance.generate_imbalanced_data(1000, 5; probs=[0.2, 0.3, 0.5])\n", + "X, y = Imbalance.generate_imbalanced_data(1000, 5; class_probs=[0.2, 0.3, 0.5])\n", "X = DataFrame(X)\n", "(X_train, X_test), (y_train, y_test) = partition((X, y), 0.8, rng=123, multi=true)\n", "Imbalance.checkbalance(y)" @@ -144,6 +152,7 @@ " balancer2 = SMOTENC(\n", " k = 10, \n", " ratios = 1.2, \n", + " knn_tree = \"Brute\", \n", " rng = 42, \n", " try_perserve_type = true), \n", " balancer3 = ROSE(\n", @@ -173,9 +182,148 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "┌ Info: Training machine(BalancedModelProbabilistic(model = LogisticClassifier(lambda = 2.220446049250313e-16, …), …), …).\n", + "└ @ MLJBase /Users/essam/.julia/packages/MLJBase/ByFwA/src/machines.jl:492\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "┌ Info: Training machine(ROSE(s = 1.0, …), …).\n", + "└ @ MLJBase /Users/essam/.julia/packages/MLJBase/ByFwA/src/machines.jl:492\n", + "┌ Info: Training machine(SMOTENC(k = 10, …), …).\n", + "└ @ MLJBase /Users/essam/.julia/packages/MLJBase/ByFwA/src/machines.jl:492\n", + "┌ Info: Training machine(RandomOversampler(ratios = 1.0, …), …).\n", + "└ @ MLJBase /Users/essam/.julia/packages/MLJBase/ByFwA/src/machines.jl:492\n", + "┌ Info: Training machine(:model, …).\n", + "└ @ MLJBase /Users/essam/.julia/packages/MLJBase/ByFwA/src/machines.jl:492\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\r\u001b[32mProgress: 67%|███████████████████████████▍ | ETA: 0:00:00\u001b[39m\u001b[K\r\n", + "\u001b[34m class: 0\u001b[39m\u001b[K\r\u001b[A" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\r\u001b[32mProgress: 67%|███████████████████████████▍ | ETA: 0:00:01\u001b[39m\u001b[K\r\n", + "\u001b[34m class: 1\u001b[39m\u001b[K\r\u001b[A" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "\r\u001b[K\u001b[A\r\u001b[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:01\u001b[39m\u001b[K\r\n", + "\u001b[34m class: 0\u001b[39m\u001b[K\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\r\u001b[32mProgress: 67%|███████████████████████████▍ | ETA: 0:00:00\u001b[39m\u001b[K\r\n", + "\u001b[34m class: 1\u001b[39m\u001b[K\r\u001b[A" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "\r\u001b[K\u001b[A\r\u001b[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:00\u001b[39m\u001b[K\r\n", + "\u001b[34m class: 0\u001b[39m\u001b[K\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\r\u001b[32mProgress: 67%|███████████████████████████▍ | ETA: 0:00:00\u001b[39m\u001b[K\r\n", + "\u001b[34m class: 1\u001b[39m\u001b[K\r\u001b[A" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "\r\u001b[K\u001b[A\r\u001b[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:00\u001b[39m\u001b[K\r\n", + "\u001b[34m class: 0\u001b[39m\u001b[K\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\r\u001b[32mProgress: 67%|███████████████████████████▍ | ETA: 0:00:00\u001b[39m\u001b[K\r\n", + "\u001b[34m class: 1\u001b[39m\u001b[K\r\u001b[A" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "\r\u001b[K\u001b[A\r\u001b[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:00\u001b[39m\u001b[K\r\n", + "\u001b[34m class: 0\u001b[39m\u001b[K\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\r\u001b[32mProgress: 67%|███████████████████████████▍ | ETA: 0:00:00\u001b[39m\u001b[K\r\n", + "\u001b[34m class: 1\u001b[39m\u001b[K\r\u001b[A" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "\r\u001b[K\u001b[A\r\u001b[32mProgress: 100%|█████████████████████████████████████████| Time: 0:00:00\u001b[39m\u001b[K\r\n", + "\u001b[34m class: 0\u001b[39m\u001b[K\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "┌ Info: Solver: MLJLinearModels.LBFGS{Optim.Options{Float64, Nothing}, NamedTuple{(), Tuple{}}}\n", + "│ optim_options: Optim.Options{Float64, Nothing}\n", + "│ lbfgs_options: NamedTuple{(), Tuple{}} NamedTuple()\n", + "└ @ MLJLinearModels /Users/essam/.julia/packages/MLJLinearModels/zSQnL/src/mlj/interface.jl:72\n" + ] + }, + { + "data": { + "text/plain": [ + "trained Machine; does not cache data\n", + " model: BalancedModelProbabilistic(model = LogisticClassifier(lambda = 2.220446049250313e-16, …), …)\n", + " args: \n", + " 1:\tSource @226 ⏎ Table{AbstractVector{Continuous}}\n", + " 2:\tSource @078 ⏎ AbstractVector{Multiclass{3}}\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "mach = machine(balanced_model, X_train, y_train)\n", "fit!(mach)" @@ -183,33 +331,33 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "200-element CategoricalDistributions.UnivariateFiniteVector{Multiclass{3}, Int64, UInt32, Float64}:\n", - " UnivariateFinite{Multiclass{3}}(0=>0.359, 1=>0.295, 2=>0.346)\n", - " UnivariateFinite{Multiclass{3}}(0=>0.384, 1=>0.294, 2=>0.322)\n", - " UnivariateFinite{Multiclass{3}}(0=>0.301, 1=>0.395, 2=>0.304)\n", - " UnivariateFinite{Multiclass{3}}(0=>0.285, 1=>0.369, 2=>0.346)\n", - " UnivariateFinite{Multiclass{3}}(0=>0.279, 1=>0.39, 2=>0.331)\n", - " UnivariateFinite{Multiclass{3}}(0=>0.31, 1=>0.34, 2=>0.35)\n", - " UnivariateFinite{Multiclass{3}}(0=>0.292, 1=>0.392, 2=>0.316)\n", - " UnivariateFinite{Multiclass{3}}(0=>0.331, 1=>0.351, 2=>0.318)\n", - " UnivariateFinite{Multiclass{3}}(0=>0.303, 1=>0.35, 2=>0.347)\n", - " UnivariateFinite{Multiclass{3}}(0=>0.311, 1=>0.351, 2=>0.338)\n", + " UnivariateFinite{Multiclass{3}}(0=>0.0, 1=>1.0, 2=>4.16e-270)\n", + " UnivariateFinite{Multiclass{3}}(0=>0.0, 1=>1.2e-217, 2=>1.0)\n", + " UnivariateFinite{Multiclass{3}}(0=>2.99e-304, 1=>1.0, 2=>1.19e-221)\n", + " UnivariateFinite{Multiclass{3}}(0=>1.0, 1=>1.35e-179, 2=>2.0900000000000003e-267)\n", + " UnivariateFinite{Multiclass{3}}(0=>0.0, 1=>1.36e-93, 2=>1.0)\n", + " UnivariateFinite{Multiclass{3}}(0=>0.0, 1=>4.01e-71, 2=>1.0)\n", + " UnivariateFinite{Multiclass{3}}(0=>1.16e-270, 1=>4.55e-103, 2=>1.0)\n", + " UnivariateFinite{Multiclass{3}}(0=>1.0, 1=>1.0299999999999999e-198, 2=>0.0)\n", + " UnivariateFinite{Multiclass{3}}(0=>1.0, 1=>2.2100000000000002e-73, 2=>1.45e-97)\n", + " UnivariateFinite{Multiclass{3}}(0=>0.0, 1=>3.4900000000000003e-75, 2=>1.0)\n", " ⋮\n", - " UnivariateFinite{Multiclass{3}}(0=>0.319, 1=>0.354, 2=>0.326)\n", - " UnivariateFinite{Multiclass{3}}(0=>0.375, 1=>0.291, 2=>0.334)\n", - " UnivariateFinite{Multiclass{3}}(0=>0.345, 1=>0.329, 2=>0.326)\n", - " UnivariateFinite{Multiclass{3}}(0=>0.312, 1=>0.343, 2=>0.345)\n", - " UnivariateFinite{Multiclass{3}}(0=>0.358, 1=>0.308, 2=>0.333)\n", - " UnivariateFinite{Multiclass{3}}(0=>0.307, 1=>0.344, 2=>0.349)\n", - " UnivariateFinite{Multiclass{3}}(0=>0.297, 1=>0.36, 2=>0.343)\n", - " UnivariateFinite{Multiclass{3}}(0=>0.358, 1=>0.312, 2=>0.33)\n", - " UnivariateFinite{Multiclass{3}}(0=>0.355, 1=>0.309, 2=>0.336)" + " UnivariateFinite{Multiclass{3}}(0=>1.3699999999999999e-239, 1=>9.34e-140, 2=>1.0)\n", + " UnivariateFinite{Multiclass{3}}(0=>0.0, 1=>1.0, 2=>2.3599999999999997e-256)\n", + " UnivariateFinite{Multiclass{3}}(0=>3.03e-149, 1=>1.69e-109, 2=>1.0)\n", + " UnivariateFinite{Multiclass{3}}(0=>0.0, 1=>1.0, 2=>3.3999999999999996e-242)\n", + " UnivariateFinite{Multiclass{3}}(0=>8.889999999999998e-259, 1=>8.98e-152, 2=>1.0)\n", + " UnivariateFinite{Multiclass{3}}(0=>1.5500000000000002e-235, 1=>7.45e-95, 2=>1.0)\n", + " UnivariateFinite{Multiclass{3}}(0=>0.0, 1=>1.0, 2=>4.3e-232)\n", + " UnivariateFinite{Multiclass{3}}(0=>0.0, 1=>2.31e-134, 2=>1.0)\n", + " UnivariateFinite{Multiclass{3}}(0=>0.0, 1=>1.0, 2=>1.21e-194)" ] }, "metadata": {}, @@ -250,41 +398,9 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "BalancedModelProbabilistic(\n", - " model = LogisticClassifier(\n", - " lambda = 2.220446049250313e-16, \n", - " gamma = 0.0, \n", - " penalty = :l2, \n", - " fit_intercept = true, \n", - " penalize_intercept = false, \n", - " scale_penalty_with_samples = true, \n", - " solver = nothing), \n", - " balancer1 = RandomOversampler(\n", - " ratios = 1.4, \n", - " rng = 42, \n", - " try_perserve_type = true), \n", - " balancer2 = SMOTENC(\n", - " k = 10, \n", - " ratios = 1.2, \n", - " rng = 42, \n", - " try_perserve_type = true), \n", - " balancer3 = ROSE(\n", - " s = 0.0, \n", - " ratios = 1.3, \n", - " rng = 42, \n", - " try_perserve_type = true))" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], + "outputs": [], "source": [ "fitted_params(mach).best_model" ] diff --git a/example/Manifest.toml b/example/Manifest.toml index e9a05c0..789b316 100644 --- a/example/Manifest.toml +++ b/example/Manifest.toml @@ -99,6 +99,12 @@ git-tree-sha1 = "2fba81a302a7be671aefe194f0525ef231104e7f" uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" version = "0.1.8" +[[deps.Clustering]] +deps = ["Distances", "LinearAlgebra", "NearestNeighbors", "Printf", "Random", "SparseArrays", "Statistics", "StatsBase"] +git-tree-sha1 = "b86ac2c5543660d238957dbde5ac04520ae977a7" +uuid = "aaaa29a8-35af-508c-8bc3-b662a17a0fe5" +version = "0.15.4" + [[deps.CodecZlib]] deps = ["TranscodingStreams", "Zlib_jll"] git-tree-sha1 = "02aa26a4cf76381be7f66e020a3eddeb27b0a092" @@ -236,9 +242,9 @@ uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" [[deps.Distributions]] deps = ["ChainRulesCore", "DensityInterface", "FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns", "Test"] -git-tree-sha1 = "9e11104e7b41a8a5f04e8694467fc1f94a135bd7" +git-tree-sha1 = "3d5873f811f582873bb9871fc9c451784d5dc8c7" uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" -version = "0.25.101" +version = "0.25.102" [[deps.DocStringExtensions]] deps = ["LibGit2"] @@ -283,9 +289,9 @@ version = "0.1.1" [[deps.FilePathsBase]] deps = ["Compat", "Dates", "Mmap", "Printf", "Test", "UUIDs"] -git-tree-sha1 = "e27c4ebe80e8699540f2d6c805cc12203b614f12" +git-tree-sha1 = "9f00e42f8d99fdde64d40c8ea5d14269a2e2c1aa" uuid = "48062228-2e41-5def-b9a4-89aafe57970f" -version = "0.9.20" +version = "0.9.21" [[deps.FileWatching]] uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" @@ -330,10 +336,10 @@ uuid = "34004b35-14d8-5ef3-9330-4cdb6864b03a" version = "0.3.23" [[deps.Imbalance]] -deps = ["CategoricalArrays", "CategoricalDistributions", "Distances", "LinearAlgebra", "MLJModelInterface", "MLJTestInterface", "NearestNeighbors", "OrderedCollections", "ProgressMeter", "Random", "ScientificTypes", "Statistics", "StatsBase", "TableOperations", "TableTransforms", "Tables", "TransformsBase"] -git-tree-sha1 = "53eeb73d88913134cab0b0e04dd58901769fc7db" +deps = ["CategoricalArrays", "CategoricalDistributions", "Clustering", "Distances", "LinearAlgebra", "MLJModelInterface", "MLJTestInterface", "NearestNeighbors", "OrderedCollections", "ProgressMeter", "Random", "ScientificTypes", "Statistics", "StatsBase", "TableOperations", "TableTransforms", "Tables", "TransformsBase"] +git-tree-sha1 = "2cdff31d45d48b8001420dc1f6f20a36ad98dd8a" uuid = "c709b415-507b-45b7-9a3d-1767c89fde68" -version = "0.1.0" +version = "0.1.1" [[deps.InitialValues]] git-tree-sha1 = "4da0f88e9a39111c2fa3add390ab15f3a44f3ca3" @@ -385,9 +391,9 @@ version = "0.21.4" [[deps.JuliaFormatter]] deps = ["CSTParser", "CommonMark", "DataStructures", "Glob", "Pkg", "PrecompileTools", "Tokenize"] -git-tree-sha1 = "2aa8cb5410821365a87f326631d7f6ce07db8882" +git-tree-sha1 = "80031f6e58b09b0de4553bf63d9a36ec5db57967" uuid = "98e50ef6-434e-11e9-1051-2b60c6c9e899" -version = "1.0.36" +version = "1.0.39" [[deps.JuliaVariables]] deps = ["MLStyle", "NameResolution"] @@ -403,15 +409,15 @@ version = "0.9.8" [[deps.LLVM]] deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Printf", "Unicode"] -git-tree-sha1 = "a9d2ce1d5007b1e8f6c5b89c5a31ff8bd146db5c" +git-tree-sha1 = "4ea2928a96acfcf8589e6cd1429eff2a3a82c366" uuid = "929cbde3-209d-540e-8aea-75f648917ca0" -version = "6.2.1" +version = "6.3.0" [[deps.LLVMExtra_jll]] deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] -git-tree-sha1 = "7ca6850ae880cc99b59b88517545f91a52020afa" +git-tree-sha1 = "e7c01b69bcbcb93fd4cbc3d0fea7d229541e18d2" uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" -version = "0.0.25+0" +version = "0.0.26+0" [[deps.LaTeXStrings]] git-tree-sha1 = "f2355693d6778a178ade15952b7ac47a4ff97996" @@ -465,9 +471,9 @@ uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" [[deps.LoggingExtras]] deps = ["Dates", "Logging"] -git-tree-sha1 = "0d097476b6c381ab7906460ef1ef1638fbce1d91" +git-tree-sha1 = "c1dd6d7978c12545b4179fb6153b9250c96b0075" uuid = "e6f89c97-d47a-5376-807f-9c37f3926c36" -version = "1.0.2" +version = "1.0.3" [[deps.LossFunctions]] deps = ["Markdown", "Requires", "Statistics"] @@ -533,9 +539,9 @@ version = "0.16.11" [[deps.MLJTestInterface]] deps = ["MLJBase", "Pkg", "Test"] -git-tree-sha1 = "9131806695e6a6d32c61ed5f7bccaadef9fef57e" +git-tree-sha1 = "0a4167e43dcd96ad293fe0bded5923703c169553" uuid = "72560011-54dd-4dc2-94f3-c5de45b75ecd" -version = "0.2.2" +version = "0.2.3" [[deps.MLJTuning]] deps = ["ComputationalResources", "Distributed", "Distributions", "LatinHypercubeSampling", "MLJBase", "ProgressMeter", "Random", "RecipesBase"] @@ -674,9 +680,9 @@ version = "1.6.2" [[deps.PDMats]] deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"] -git-tree-sha1 = "bf6085e8bd7735e68c210c6e5d81f9a6fe192060" +git-tree-sha1 = "528664265c9c36b3ecdb6d721d47aaab52ddf267" uuid = "90014a1f-27ba-587c-ab20-58faa44d9150" -version = "0.11.19" +version = "0.11.24" [[deps.Parameters]] deps = ["OrderedCollections", "UnPack"] @@ -867,9 +873,9 @@ version = "1.0.0" [[deps.StaticArrays]] deps = ["LinearAlgebra", "Random", "StaticArraysCore", "Statistics"] -git-tree-sha1 = "d5fb407ec3179063214bc6277712928ba78459e2" +git-tree-sha1 = "0adf069a2a490c47273727e029371b31d44b72b2" uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.6.4" +version = "1.6.5" [[deps.StaticArraysCore]] git-tree-sha1 = "36b3d696ce6366023a0ea192b4cd442268995a0d" diff --git a/src/balanced_bagging.jl b/src/balanced_bagging.jl index ec7496f..4137077 100644 --- a/src/balanced_bagging.jl +++ b/src/balanced_bagging.jl @@ -234,8 +234,8 @@ logistic_model = LogisticClassifier() model = BalancedBaggingClassifier(model=logistic_model, T=5) # Load the data and train the BalancedBaggingClassifier -X, y = Imbalance.generate_imbalanced_data(100, 5; cat_feats_num_vals = [3, 2], - probs = [0.9, 0.1], +X, y = Imbalance.generate_imbalanced_data(100, 5; num_vals_per_category = [3, 2], + class_probs = [0.9, 0.1], type = "ColTable", rng=42) julia> Imbalance.checkbalance(y) diff --git a/test/balanced_bagging.jl b/test/balanced_bagging.jl index ddd6252..678192d 100644 --- a/test/balanced_bagging.jl +++ b/test/balanced_bagging.jl @@ -14,8 +14,8 @@ end X, y = generate_imbalanced_data( 100, 5; - cat_feats_num_vals = [3, 2, 1, 2], - probs = [0.9, 0.1], + num_vals_per_category = [3, 2, 1, 2], + class_probs = [0.9, 0.1], type = "ColTable", rng = 42, ) @@ -60,8 +60,8 @@ end X, y = generate_imbalanced_data( 100, 5; - cat_feats_num_vals = [3, 2, 1, 2], - probs = [0.9, 0.1], + num_vals_per_category = [3, 2, 1, 2], + class_probs = [0.9, 0.1], type = "ColTable", rng = 42, ) @@ -69,8 +69,8 @@ end Xt, yt = generate_imbalanced_data( 5, 5; - cat_feats_num_vals = [3, 2, 1, 2], - probs = [0.9, 0.1], + num_vals_per_category = [3, 2, 1, 2], + class_probs = [0.9, 0.1], type = "ColTable", rng = 42, ) diff --git a/test/balanced_model.jl b/test/balanced_model.jl index 0552a5f..19576b6 100644 --- a/test/balanced_model.jl +++ b/test/balanced_model.jl @@ -1,7 +1,7 @@ @testset "BalancedModel" begin ### end-to-end test # Create and split data - X, y = generate_imbalanced_data(100, 5; probs = [0.2, 0.3, 0.5]) + X, y = generate_imbalanced_data(100, 5; class_probs = [0.2, 0.3, 0.5]) X = DataFrame(X) train_inds, test_inds = partition(eachindex(y), 0.8, shuffle = true, stratify = y, rng = Random.Xoshiro(42))