Skip to content

Commit

Permalink
Merge pull request #35 from jmwerner/add_rf_tests
Browse files Browse the repository at this point in the history
Add fixed rf tests
  • Loading branch information
jmwerner authored Apr 10, 2019
2 parents 887e952 + 1f83806 commit 8d84d48
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 16 deletions.
6 changes: 1 addition & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,4 @@ As an exercise of professional development (but mostly for funzies) I will be wr

* [k-Means](http://jmwerner.github.io/ArtisanalMachineLearning/inst/doc/k-means.html)
* [Simple Neural Network](http://jmwerner.github.io/ArtisanalMachineLearning/inst/doc/neural-network.html)

### In Progress / On Deck

* Random Forest
* GBM
* [Trees](http://jmwerner.github.io/ArtisanalMachineLearning/inst/doc/trees.html)
16 changes: 5 additions & 11 deletions tests/testthat/test-trees.R
Original file line number Diff line number Diff line change
Expand Up @@ -290,24 +290,18 @@ test_that("random forest prediction works correctly", {
expect_error(predict(forest, data.frame(a=rnorm(10), b=rnorm(10), c = rnorm(10))), "ERROR: data must be a data.frame of dimension 1 x p")
expect_error(predict(forest, c(1, 2, 3)), "ERROR: data must be a data.frame")

expected_predictions = c(1.19124714218882, 1.22298896784481, 1.21150233682877, 1.21104603980511,1.20929176693527, 1.1690221455812, 1.224876994208, 1.19734883889557,1.22298896784481, 1.21104603980511, 1.15802921771698, 1.2129340661683,1.22298896784481, 1.22298896784481, 1.16997214575668, 1.15802921771698,1.16997214575668, 1.19124714218882, 1.1690221455812, 1.17930421414912,1.1690221455812, 1.17930421414912, 1.224876994208, 1.19238736757725,1.23670881221433, 1.19546081253238, 1.19734883889557, 1.17930421414912,1.19124714218882, 1.19955940878906, 1.21104603980511, 1.15802921771698,1.17930421414912, 1.16997214575668, 1.21104603980511, 1.19591710955604,1.16997214575668, 1.224876994208, 1.22298896784481, 1.17930421414912,1.20929176693527, 1.21310171294285, 1.21150233682877, 1.19734883889557,1.20307896019515, 1.22298896784481, 1.17930421414912, 1.21150233682877,1.17930421414912, 1.21138199249918, 1.18907102305795, 1.16974940143633,1.20055765407399, 1.17060029270193, 1.19067039917203, 1.18999423209587,1.22318136054655, 1.22493353094918, 1.20055765407399, 1.18880959190027,1.20934830367645, 1.19988148699783, 1.17060029270193, 1.21920310861945,1.16806688079605, 1.18123603245237, 1.19988148699783, 1.17060029270193,1.18999423209587, 1.16753459546813, 1.20771647760341, 1.17060029270193,1.20931585371749, 1.20931585371749, 1.18123603245237, 1.18123603245237,1.19067039917203, 1.18880007831642, 1.19988148699783, 1.17002868249786,1.16753459546813, 1.16753459546813, 1.16753459546813, 1.21761818415779,1.19988148699783, 1.18460467819618, 1.20055765407399, 1.18999423209587,1.18048754760389, 1.17060029270193, 1.18999423209587, 1.21920310861945,1.17060029270193, 1.20934830367645, 1.18999423209587, 1.19988148699783,1.19988148699783, 1.19988148699783, 1.19130367893, 1.17060029270193,1.23075287396293, 1.21761818415779, 1.20812916749037, 1.22677462203583,1.20812916749037, 1.20812916749037, 1.21615017157921, 1.20812916749037,1.19824191258841, 1.20459594843232, 1.19737335349825, 1.19897272961233,1.20885998451429, 1.19755827795991, 1.21761818415779, 1.19737335349825,1.20885998451429, 1.20459594843232, 1.19824191258841, 1.19755827795991,1.19664253647433, 1.20931585371749, 1.19824191258841, 1.20931585371749,1.21210741941747, 1.19664253647433, 1.20931585371749, 1.21920310861945,1.19824191258841, 1.20812916749037, 1.19824191258841, 1.20459594843232,1.19824191258841, 1.21761818415779, 1.21688736713387, 1.20812916749037,1.22324140297777, 1.20885998451429, 1.21920310861945, 1.20885998451429,1.20812916749037, 1.20885998451429, 1.21761818415779, 1.19664253647433,1.21210741941747, 1.20885998451429, 1.19755827795991, 1.20885998451429,1.22397222000169, 1.22750543905975)
expected_predictions = c(0.302354120762067, 0.290477871864219, 0.290477871864219, 0.293961742831961, 0.29461251422046, 0.305837991729809, 0.290477871864219, 0.298096385188202, 0.290477871864219, 0.293961742831961, 0.305837991729809, 0.293961742831961, 0.290477871864219, 0.290477871864219, 0.577327158467041, 0.533238711341917, 0.302354120762067, 0.302354120762067, 0.533238711341917, 0.305837991729809, 0.305837991729809, 0.305837991729809, 0.290477871864219, 0.305837991729809, 0.293961742831961, 0.298096385188202, 0.298096385188202, 0.305837991729809, 0.302354120762067, 0.293961742831961, 0.293961742831961, 0.305837991729809, 0.305837991729809, 0.302354120762067, 0.293961742831961, 0.29461251422046, 0.302354120762067, 0.290477871864219, 0.290477871864219, 0.305837991729809, 0.29461251422046, 0.290477871864219, 0.290477871864219, 0.298096385188202, 0.305837991729809, 0.290477871864219, 0.305837991729809, 0.290477871864219, 0.305837991729809, 0.29461251422046, 1.47450315339333, 1.47450315339333, 1.86984880426984, 1.07835265698102, 1.45830025484261, 1.35842601367282, 1.47450315339333, 0.960267890280962, 1.45830025484261, 1.07835265698102, 0.960267890280962, 1.36459309395339, 1.35332569468599, 1.40599833176569, 1.26681337659312, 1.43309791558104, 1.35842601367282, 1.35332569468599, 1.45830025484261, 1.30575337659312, 1.71687844534765, 1.35332569468599, 1.85364590571911, 1.40599833176569, 1.41689501703031, 1.41689501703031, 1.76918036842457, 1.85364590571911, 1.40599833176569, 1.18766860989307, 1.07835265698102, 1.03941265698102, 1.35332569468599, 1.85110294526614, 1.13102529406071, 1.40599833176569, 1.47450315339333, 1.41689501703031, 1.30575337659312, 1.07835265698102, 1.08962005624841, 1.40599833176569, 1.35332569468599, 0.960267890280962, 1.31702077586052, 1.31702077586052, 1.31702077586052, 1.41689501703031, 0.960267890280962, 1.30575337659312, 1.99010540973436, 1.85110294526614, 1.96176172443416, 1.96176172443416, 1.96176172443416, 1.96176172443416, 1.13102529406071, 1.96176172443416, 1.96176172443416, 1.99010540973436, 1.93174855364327, 1.95005482788243, 1.95005482788243, 1.75377166454932, 1.85110294526614, 1.97839851318264, 1.95005482788243, 1.99010540973436, 1.96176172443416, 1.80134398264219, 1.99010540973436, 1.75377166454932, 1.96176172443416, 1.85364590571911, 1.99010540973436, 1.99010540973436, 1.76918036842457, 1.80134398264219, 1.96176172443416, 1.96176172443416, 1.96176172443416, 1.99010540973436, 1.96176172443416, 1.90340486834306, 1.90945980135723, 1.96176172443416, 1.99010540973436, 1.97839851318264, 1.71687844534765, 1.97839851318264, 1.99010540973436, 1.93174855364327, 1.85110294526614, 1.99010540973436, 1.99010540973436, 1.95005482788243, 1.85364590571911, 1.95005482788243, 1.97839851318264, 1.85110294526614)
predictions = sapply(1:nrow(data), function(i){predict(forest, data[i,])})
expect_equal(expected_predictions, predictions, tolerance = .00001)

set.seed(8675309)
forest = aml_random_forest(data, response, evaluation_criterion = sum_of_squares, b = 2, m = 2, min_obs = 20, max_depth = 20)

expect_equal(predict(forest, data.frame(a=1, b=1, c=1)), mean(c(1.215254, 0.9857143)), tolerance = .00001)
expect_equal(predict(forest, data.frame(a=1, b=10, c=1)), mean(c(1.215254, 1.264348)), tolerance = .00001)
expect_equal(predict(forest, data.frame(a=1, b=1, c=1)), mean(c(0.2375, 0.2410714)), tolerance = .00001)
expect_equal(predict(forest, data.frame(a=1, b=10, c=2.5)), mean(c(0.2375, 1.3)), tolerance = .00001)

expect_equal(predict(forest, data.frame(a=10, b=1, c=1)), mean(c(1.470968, 0.9857143)), tolerance = .00001)
expect_equal(predict(forest, data.frame(a=10, b=10, c=1)), mean(c(1.470968, 1.264348)), tolerance = .00001)

expect_equal(predict(forest, data.frame(a=10, b=1, c=10)), mean(c(1.3, 0.9857143)), tolerance = .00001)
expect_equal(predict(forest, data.frame(a=10, b=10, c=10)), mean(c(1.3, 1.264348)), tolerance = .00001)

expect_equal(predict(forest, data.frame(a=1, b=1, c=10)), mean(c(0.8947368, 0.9857143)), tolerance = .00001)
expect_equal(predict(forest, data.frame(a=1, b=10, c=10)), mean(c(0.8947368, 1.264348)), tolerance = .00001)
expect_equal(predict(forest, data.frame(a=10, b=1, c=10)), mean(c(1.973333, 1.998039)), tolerance = .00001)
expect_equal(predict(forest, data.frame(a=10, b=10, c=4)), mean(c(1.352381, 1.3)), tolerance = .00001)
})


Expand Down

0 comments on commit 8d84d48

Please sign in to comment.