-
Notifications
You must be signed in to change notification settings - Fork 42
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #154 from tlverse/devel
Update master from devel
- Loading branch information
Showing
293 changed files
with
37,642 additions
and
2,888 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,42 +1,55 @@ | ||
branches: | ||
only: | ||
- master | ||
- devel | ||
|
||
env: | ||
global: | ||
- RGL_USE_NULL=TRUE | ||
- PKG_CFLAGS="-O3 -Wall -pedantic" | ||
|
||
language: r | ||
dist: trusty | ||
sudo: required | ||
cache: packages | ||
|
||
cran: "http://cran.rstudio.com" | ||
warnings_are_errors: true | ||
r_check_args: '' | ||
r_build_args: '--no-manual' | ||
r_check_args: '--no-build-vignettes --no-manual' | ||
|
||
r: | ||
- release | ||
- devel | ||
before_install: Rscript -e 'update.packages(ask = FALSE)' | ||
|
||
before_install: | ||
- sudo apt-get -y install python3-pip python-dev | ||
- sudo pip install numpy | ||
- sudo pip install tensorflow | ||
- sudo pip install keras | ||
|
||
- sudo pip install numpy tensorflow keras | ||
|
||
r_packages: | ||
- devtools | ||
- covr | ||
- drat | ||
|
||
r_github_packages: | ||
- jimhester/covr | ||
|
||
|
||
script: | ||
- Rscript -e "devtools::install_github(c('osofr/condensier', | ||
'tlverse/delayed', | ||
'tlverse/origami', | ||
'jeremyrcoyle/hal9001'), | ||
upgrade_dependencies = FALSE)" | ||
|
||
#r_github_packages: | ||
#- osofr/condensier | ||
#- tlverse/delayed | ||
#- tlverse/origami | ||
#- jeremyrcoyle/hal9001 | ||
|
||
after_success: | ||
- travis_wait Rscript -e 'covr::codecov()' | ||
- test $TRAVIS_PULL_REQUEST == "true" && test $TRAVIS_BRANCH == "master" && bash deploy.sh | ||
|
||
on_failure: | ||
- travis_wait 80 Rscript -e 'covr::codecov()' | ||
|
||
_failure: | ||
- "./travis-tool.sh dump_logs" | ||
|
||
notifications: | ||
email: | ||
on_success: change | ||
on_failure: change | ||
|
||
env: | ||
global: | ||
- RGL_USE_NULL=TRUE | ||
- PKG_CFLAGS="-O3 -Wall -pedantic" | ||
- secure: IH0Tiyhb9aj5Rd/o44LiNf7L+mDTLhLHIVfv2iR8V7WKp5uT6QAvUzNvKLDhMwRfsNd9Wa9C67oXA6ROUoBkeyrETUT79BgZ+DG77EJ3i3XE153IHpGEFxW5gnEpFz4Sn6bS6qncfaKB2ocnJwByEfWCk2uMt5onBn7q5WAheuX0eeg6X3DJmJa+nTCAIQWRv/F0PLup5z0BTobAF2Qddp3KWug9WuWnyUPXJDLWww4IpU9V2P7DL9vsgwo/WqA59AbdPRqZbTCQh8kuJq2ETnQwfqwL6kofQnGeB/KNrIVLfvGucRQpZvF/7a1QDZvXd5RvQjBLS+8eoqb25bCSHtUr1UQI4Dpyf15LJThXksgPmm8pNdO7RYtjBENM0sD4eCyTwW/MXibGKCexgFlI0T4jwnyMpNOyZlefUolIjjjZh0e4wOpSY6kXSq9bz8EyloA0qcyfbJ5UKhwm9RVsXZCIbvSRtPyMoo8g8526Pbzs4eLND79mdQFy/2EhshCLnz/iH9QrNeNkWGbMHHUrkxChZWbP2LPt68PbYP3qPf3qbRmzamAjqO/hN4/xzls7/V23dW8aocHDSYi29R3UfKV69jgVYI3YXv7pUlxrqPNScpTZOxjAomZntWWGQVG6vhOyKuRc0z93X+zqaWdTUtVlcZKaEI+OPQeZWEFU9gY= | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
#' BART Machine Learner | ||
#' | ||
#' This learner implements Bayesian Additive Regression Trees, using the | ||
#' \code{bartMachine} package. | ||
#' | ||
#' @docType class | ||
#' | ||
#' @importFrom R6 R6Class | ||
#' @importFrom stats predict | ||
#' @importFrom assertthat assert_that is.count is.flag | ||
#' | ||
#' @export | ||
#' | ||
#' @keywords data | ||
#' | ||
#' @return Learner object with methods for training and prediction. See | ||
#' \code{\link{Lrnr_base}} for documentation on learners. | ||
#' | ||
#' @format \code{\link{R6Class}} object. | ||
#' | ||
#' @family Learners | ||
#' | ||
#' @section Parameters: | ||
#' \describe{ | ||
#' \item{\code{Y}}{Outcome variable.} | ||
#' \item{\code{X}}{Covariate dataframe.} | ||
#' \item{\code{newX}}{Optional dataframe to predict the outcome.} | ||
#' \item{\code{obsWeights}}{Optional observation-level weights (supported but not tested).} | ||
#' \item{\code{id}}{Optional id to group observations from the same unit (not used | ||
#' currently).} | ||
#' \item{\code{family}}{"gaussian" for regression, "binomial" for binary classification.} | ||
#' \item{\code{num_trees }}{The number of trees to be grown in the sum-of-trees model.} | ||
#' \item{\code{num_burn_in}}{Number of MCMC samples to be discarded as "burn-in".} | ||
#' \item{\code{num_iterations_after_burn_in}}{Number of MCMC samples to draw from the | ||
#' posterior distribution of f(x).} | ||
#' \item{\code{alpha}}{Base hyperparameter in tree prior for whether a node is | ||
#' nonterminal or not.} | ||
#' \item{\code{beta}}{Power hyperparameter in tree prior for whether a node is | ||
#' nonterminal or not.} | ||
#' \item{\code{k}}{For regression, k determines the prior probability that E(Y|X) is | ||
#' contained in the interval (y_{min}, y_{max}), based on a normal | ||
#' distribution. For example, when k=2, the prior probability is 95\%. For | ||
#' classification, k determines the prior probability that E(Y|X) is between | ||
#' (-3,3). Note that a larger value of k results in more shrinkage and a more | ||
#' conservative fit.} | ||
#' \item{\code{q}}{Quantile of the prior on the error variance at which the data-based | ||
#' estimate is placed. Note that the larger the value of q, the more | ||
#' aggressive the fit as you are placing more prior weight on values lower | ||
#' than the data-based estimate. Not used for classification.} | ||
#' \item{\code{nu}}{Degrees of freedom for the inverse chi^2 prior. Not used for | ||
#' classification.} | ||
#' \item{\code{verbose }}{Prints information about progress of the algorithm to the | ||
#' screen.} | ||
#' | ||
#' } | ||
#' | ||
#' @template common_parameters | ||
# | ||
|
||
Lrnr_bartMachine <- R6Class( | ||
classname = "Lrnr_bartMachine", | ||
inherit = Lrnr_base, portable = TRUE, class = TRUE, | ||
public = list( | ||
initialize = function(num_trees = 50, num_burn_in = 250, verbose = F, | ||
alpha = 0.95, beta = 2, k = 2, q = 0.9, nu = 3, | ||
num_iterations_after_burn_in = 1000, | ||
prob_rule_class = 0.5, ...) { | ||
super$initialize(params = args_to_list(), ...) | ||
} | ||
), | ||
|
||
private = list( | ||
.properties = c("continuous", "binomial", "categorical", "weights"), | ||
|
||
.train = function(task) { | ||
args <- self$params | ||
outcome_type <- self$get_outcome_type(task) | ||
|
||
# specify data | ||
args$X <- as.data.frame(task$X) | ||
args$y <- outcome_type$format(task$Y) | ||
|
||
if (task$has_node("weights")) { | ||
args$weights <- task$weights | ||
} | ||
|
||
if (task$has_node("offset")) { | ||
args$offset <- task$offset | ||
} | ||
|
||
fit_object <- call_with_args(bartMachine::bartMachine, args) | ||
|
||
return(fit_object) | ||
}, | ||
|
||
.predict = function(task) { | ||
# outcome_type <- private$.training_outcome_type | ||
predictions <- stats::predict( | ||
private$.fit_object, | ||
new_data = data.frame(task$X) | ||
) | ||
|
||
return(predictions) | ||
}, | ||
.required_packages = c("rJava", "bartMachine") | ||
) | ||
) |
Oops, something went wrong.