From 3214c6dc709ca3628bc5e728144970a225a42e6e Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Mon, 16 Mar 2015 10:06:35 -0700 Subject: [PATCH] Merge pull request #217 from hlin09/cleanClosureFix Fix cleanClosure() on recursive function calls. --- pkg/R/utils.R | 77 +++++++++++++++++++++++-------------- pkg/inst/tests/test_utils.R | 35 ++++++++--------- 2 files changed, 66 insertions(+), 46 deletions(-) diff --git a/pkg/R/utils.R b/pkg/R/utils.R index 24b04436a33dc..6c99866d57020 100644 --- a/pkg/R/utils.R +++ b/pkg/R/utils.R @@ -320,51 +320,52 @@ listToSeq <- function(l) { # param # node The current AST node in the traversal. # oldEnv The original function environment. -# argNames A character vector of parameters of the function. Their values are -# passed in as arguments, and not included in the closure. -# newEnv A new function environment to store necessary function dependencies. -processClosure <- function(node, oldEnv, argNames, newEnv) { +# defVars An Accumulator of variables names defined in the function's calling environment, +# including function argument and local variable names. +# checkedFunc An environment of function objects examined during cleanClosure. It can +# be considered as a "name"-to-"list of functions" mapping. +# newEnv A new function environment to store necessary function dependencies, an output argument. +processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) { nodeLen <- length(node) if (nodeLen > 1 && typeof(node) == "language") { # Recursive case: current AST node is an internal node, check for its children. if (length(node[[1]]) > 1) { for (i in 1:nodeLen) { - processClosure(node[[i]], oldEnv, argNames, newEnv) + processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) } } else { # if node[[1]] is length of 1, check for some R special functions. nodeChar <- as.character(node[[1]]) if (nodeChar == "{" || nodeChar == "(") { # Skip start symbol. 
for (i in 2:nodeLen) { - processClosure(node[[i]], oldEnv, argNames, newEnv) + processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) } } else if (nodeChar == "<-" || nodeChar == "=" || nodeChar == "<<-") { # Assignment Ops. defVar <- node[[2]] if (length(defVar) == 1 && typeof(defVar) == "symbol") { - # Add the defined variable name into .defVars. - assign(".defVars", - c(get(".defVars", envir = .sparkREnv), as.character(defVar)), - envir = .sparkREnv) + # Add the defined variable name into defVars. + addItemToAccumulator(defVars, as.character(defVar)) } else { - processClosure(node[[2]], oldEnv, argNames, newEnv) + processClosure(node[[2]], oldEnv, defVars, checkedFuncs, newEnv) } for (i in 3:nodeLen) { - processClosure(node[[i]], oldEnv, argNames, newEnv) + processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) } } else if (nodeChar == "function") { # Function definition. + # Add parameter names. newArgs <- names(node[[2]]) - argNames <- c(argNames, newArgs) # Add parameter names. + lapply(newArgs, function(arg) { addItemToAccumulator(defVars, arg) }) for (i in 3:nodeLen) { - processClosure(node[[i]], oldEnv, argNames, newEnv) + processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) } } else if (nodeChar == "$") { # Skip the field. - processClosure(node[[2]], oldEnv, argNames, newEnv) + processClosure(node[[2]], oldEnv, defVars, checkedFuncs, newEnv) } else if (nodeChar == "::" || nodeChar == ":::") { - processClosure(node[[3]], oldEnv, argNames, newEnv) + processClosure(node[[3]], oldEnv, defVars, checkedFuncs, newEnv) } else { for (i in 1:nodeLen) { - processClosure(node[[i]], oldEnv, argNames, newEnv) + processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) } } } @@ -372,8 +373,7 @@ processClosure <- function(node, oldEnv, argNames, newEnv) { (typeof(node) == "symbol" || typeof(node) == "language")) { # Base case: current AST node is a leaf node and a symbol or a function call. 
nodeChar <- as.character(node) - if (!nodeChar %in% argNames && # Not a function parameter or function local variable. - !nodeChar %in% get(".defVars", envir = .sparkREnv)) { + if (!nodeChar %in% defVars$data) { # Not a function parameter or local variable. func.env <- oldEnv topEnv <- parent.env(.GlobalEnv) # Search in function environment, and function's enclosing environments @@ -387,11 +387,27 @@ processClosure <- function(node, oldEnv, argNames, newEnv) { !(nodeChar %in% getNamespaceExports("SparkR")))) { # Only include SparkR internals. # Set parameter 'inherits' to FALSE since we do not need to search in # attached package environments. - if (exists(nodeChar, envir = func.env, inherits = FALSE)) { + if (tryCatch(exists(nodeChar, envir = func.env, inherits = FALSE), + error = function(e) { FALSE })) { obj <- get(nodeChar, envir = func.env, inherits = FALSE) - if (is.function(obj)) { - # if the node is a function call, recursively clean its closure. - obj <- cleanClosure(obj) + if (is.function(obj)) { # If the node is a function call. + funcList <- mget(nodeChar, envir = checkedFuncs, inherits = F, + ifnotfound = list(list(NULL)))[[1]] + found <- sapply(funcList, function(func) { + ifelse(identical(func, obj), TRUE, FALSE) + }) + if (sum(found) > 0) { # If function has been examined, ignore. + break + } + # Function has not been examined, record it and recursively clean its closure. + assign(nodeChar, + if (is.null(funcList[[1]])) { + list(obj) + } else { + append(funcList, obj) + }, + envir = checkedFuncs) + obj <- cleanClosure(obj, checkedFuncs) } assign(nodeChar, obj, envir = newEnv) break @@ -410,19 +426,24 @@ processClosure <- function(node, oldEnv, argNames, newEnv) { # outside a UDF, and stores them in the function's environment. # param # func A function whose closure needs to be captured. +# checkedFuncs An environment of function objects examined during cleanClosure. It can be +# considered as a "name"-to-"list of functions" mapping. 
# return value # a new version of func that has an correct environment (closure). -cleanClosure <- function(func) { +cleanClosure <- function(func, checkedFuncs = new.env()) { if (is.function(func)) { newEnv <- new.env(parent = .GlobalEnv) - # .defVars is a character vector of variables names defined in the function. - assign(".defVars", c(), envir = .sparkREnv) func.body <- body(func) oldEnv <- environment(func) + # defVars is an Accumulator of variable names defined in the function's calling + # environment. First, function's arguments are added to defVars. + defVars <- initAccumulator() argNames <- names(as.list(args(func))) - argsNames <- argNames[-length(argNames)] # Remove the ending NULL in pairlist. + for (i in 1:(length(argNames) - 1)) { # Remove the ending NULL in pairlist. + addItemToAccumulator(defVars, argNames[i]) + } # Recursively examine variables in the function body. - processClosure(func.body, oldEnv, argNames, newEnv) + processClosure(func.body, oldEnv, defVars, checkedFuncs, newEnv) environment(func) <- newEnv } func diff --git a/pkg/inst/tests/test_utils.R b/pkg/inst/tests/test_utils.R index 910b110c6e651..e9582db9d98dc 100644 --- a/pkg/inst/tests/test_utils.R +++ b/pkg/inst/tests/test_utils.R @@ -61,13 +61,26 @@ test_that("cleanClosure on R functions", { actual <- get("g", envir = env, inherits = FALSE) expect_equal(actual, g) - # Test for recursive closure capture for a free variable of a function. + base <- c(1, 2, 3) + l <- list(field = matrix(1)) + field <- matrix(2) + defUse <- 3 g <- function(x) { x + y } - f <- function(x) { lapply(x, g) + 1 } + f <- function(x) { + defUse <- base::as.integer(x) + 1 # Test for access operators `::`. + lapply(x, g) + 1 # Test for capturing function call "g"'s closure as an argument of lapply. + l$field[1,1] <- 3 # Test for access operators `$`. + res <- defUse + l$field[1,] # Test for def-use chain of "defUse", and "" symbol. + f(res) # Test for recursive calls. 
+ } newF <- cleanClosure(f) env <- environment(newF) - expect_equal(length(ls(env)), 1) # Only "g". "y" should be in the environemnt of g. - expect_equal(ls(env), "g") + expect_equal(length(ls(env)), 3) # Only "g", "l" and "f". No "base", "field" or "defUse". + expect_true("g" %in% ls(env)) + expect_true("l" %in% ls(env)) + expect_true("f" %in% ls(env)) + expect_equal(get("l", envir = env, inherits = FALSE), l) + # "y" should be in the environment of g. newG <- get("g", envir = env, inherits = FALSE) env <- environment(newG) expect_equal(length(ls(env)), 1) @@ -83,20 +96,6 @@ test_that("cleanClosure on R functions", { env <- environment(newF) expect_equal(length(ls(env)), 0) # "y" and "g" should not be included. - # Test for access operators `$`, `::` and `:::`. - l <- list(a = 1) - a <- 2 - base <- c(1, 2, 3) - f <- function(x) { - z <- base::as.integer(x) + 1 - l$a <- 3 - z + l$a - } - newF <- cleanClosure(f) - env <- environment(newF) - expect_equal(ls(env), "l") # "base" and "a" should not be included. - expect_equal(get("l", envir = env, inherits = FALSE), l) - # Test for overriding variables in base namespace (Issue: SparkR-196). nums <- as.list(1:10) rdd <- parallelize(sc, nums, 2L)