Skip to content

Commit

Permalink
Merge pull request apache#217 from hlin09/cleanClosureFix
Browse files Browse the repository at this point in the history
Fix cleanClosure() on recursive function calls.
  • Loading branch information
shivaram committed Mar 16, 2015
1 parent f5d3355 commit 3214c6d
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 46 deletions.
77 changes: 49 additions & 28 deletions pkg/R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -320,60 +320,60 @@ listToSeq <- function(l) {
# param
# node The current AST node in the traversal.
# oldEnv The original function environment.
# argNames A character vector of parameters of the function. Their values are
# passed in as arguments, and not included in the closure.
# newEnv A new function environment to store necessary function dependencies.
processClosure <- function(node, oldEnv, argNames, newEnv) {
# defVars An Accumulator of variable names defined in the function's calling environment,
# including function argument and local variable names.
# checkedFuncs An environment of function objects examined during cleanClosure. It can
# be considered as a "name"-to-"list of functions" mapping.
# newEnv A new function environment to store necessary function dependencies, an output argument.
processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) {
nodeLen <- length(node)

if (nodeLen > 1 && typeof(node) == "language") {
# Recursive case: current AST node is an internal node, check for its children.
if (length(node[[1]]) > 1) {
for (i in 1:nodeLen) {
processClosure(node[[i]], oldEnv, argNames, newEnv)
processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv)
}
} else { # if node[[1]] is length of 1, check for some R special functions.
nodeChar <- as.character(node[[1]])
if (nodeChar == "{" || nodeChar == "(") { # Skip start symbol.
for (i in 2:nodeLen) {
processClosure(node[[i]], oldEnv, argNames, newEnv)
processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv)
}
} else if (nodeChar == "<-" || nodeChar == "=" ||
nodeChar == "<<-") { # Assignment Ops.
defVar <- node[[2]]
if (length(defVar) == 1 && typeof(defVar) == "symbol") {
# Add the defined variable name into .defVars.
assign(".defVars",
c(get(".defVars", envir = .sparkREnv), as.character(defVar)),
envir = .sparkREnv)
# Add the defined variable name into defVars.
addItemToAccumulator(defVars, as.character(defVar))
} else {
processClosure(node[[2]], oldEnv, argNames, newEnv)
processClosure(node[[2]], oldEnv, defVars, checkedFuncs, newEnv)
}
for (i in 3:nodeLen) {
processClosure(node[[i]], oldEnv, argNames, newEnv)
processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv)
}
} else if (nodeChar == "function") { # Function definition.
# Add parameter names.
newArgs <- names(node[[2]])
argNames <- c(argNames, newArgs) # Add parameter names.
lapply(newArgs, function(arg) { addItemToAccumulator(defVars, arg) })
for (i in 3:nodeLen) {
processClosure(node[[i]], oldEnv, argNames, newEnv)
processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv)
}
} else if (nodeChar == "$") { # Skip the field.
processClosure(node[[2]], oldEnv, argNames, newEnv)
processClosure(node[[2]], oldEnv, defVars, checkedFuncs, newEnv)
} else if (nodeChar == "::" || nodeChar == ":::") {
processClosure(node[[3]], oldEnv, argNames, newEnv)
processClosure(node[[3]], oldEnv, defVars, checkedFuncs, newEnv)
} else {
for (i in 1:nodeLen) {
processClosure(node[[i]], oldEnv, argNames, newEnv)
processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv)
}
}
}
} else if (nodeLen == 1 &&
(typeof(node) == "symbol" || typeof(node) == "language")) {
# Base case: current AST node is a leaf node and a symbol or a function call.
nodeChar <- as.character(node)
if (!nodeChar %in% argNames && # Not a function parameter or function local variable.
!nodeChar %in% get(".defVars", envir = .sparkREnv)) {
if (!nodeChar %in% defVars$data) { # Not a function parameter or local variable.
func.env <- oldEnv
topEnv <- parent.env(.GlobalEnv)
# Search in function environment, and function's enclosing environments
Expand All @@ -387,11 +387,27 @@ processClosure <- function(node, oldEnv, argNames, newEnv) {
!(nodeChar %in% getNamespaceExports("SparkR")))) { # Only include SparkR internals.
# Set parameter 'inherits' to FALSE since we do not need to search in
# attached package environments.
if (exists(nodeChar, envir = func.env, inherits = FALSE)) {
if (tryCatch(exists(nodeChar, envir = func.env, inherits = FALSE),
error = function(e) { FALSE })) {
obj <- get(nodeChar, envir = func.env, inherits = FALSE)
if (is.function(obj)) {
# if the node is a function call, recursively clean its closure.
obj <- cleanClosure(obj)
if (is.function(obj)) { # If the node is a function call.
funcList <- mget(nodeChar, envir = checkedFuncs, inherits = F,
ifnotfound = list(list(NULL)))[[1]]
found <- sapply(funcList, function(func) {
ifelse(identical(func, obj), TRUE, FALSE)
})
if (sum(found) > 0) { # If function has been examined, ignore.
break
}
# Function has not been examined, record it and recursively clean its closure.
assign(nodeChar,
if (is.null(funcList[[1]])) {
list(obj)
} else {
append(funcList, obj)
},
envir = checkedFuncs)
obj <- cleanClosure(obj, checkedFuncs)
}
assign(nodeChar, obj, envir = newEnv)
break
Expand All @@ -410,19 +426,24 @@ processClosure <- function(node, oldEnv, argNames, newEnv) {
# outside a UDF, and stores them in the function's environment.
# param
# func A function whose closure needs to be captured.
# checkedFuncs An environment of function objects examined during cleanClosure. It can be
# considered as a "name"-to-"list of functions" mapping.
# return value
# a new version of func that has a correct environment (closure).
cleanClosure <- function(func) {
cleanClosure <- function(func, checkedFuncs = new.env()) {
if (is.function(func)) {
newEnv <- new.env(parent = .GlobalEnv)
# .defVars is a character vector of variables names defined in the function.
assign(".defVars", c(), envir = .sparkREnv)
func.body <- body(func)
oldEnv <- environment(func)
# defVars is an Accumulator of variable names defined in the function's calling
# environment. First, function's arguments are added to defVars.
defVars <- initAccumulator()
argNames <- names(as.list(args(func)))
argsNames <- argNames[-length(argNames)] # Remove the ending NULL in pairlist.
for (i in 1:(length(argNames) - 1)) { # Remove the ending NULL in pairlist.
addItemToAccumulator(defVars, argNames[i])
}
# Recursively examine variables in the function body.
processClosure(func.body, oldEnv, argNames, newEnv)
processClosure(func.body, oldEnv, defVars, checkedFuncs, newEnv)
environment(func) <- newEnv
}
func
Expand Down
35 changes: 17 additions & 18 deletions pkg/inst/tests/test_utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,26 @@ test_that("cleanClosure on R functions", {
actual <- get("g", envir = env, inherits = FALSE)
expect_equal(actual, g)

# Test for recursive closure capture for a free variable of a function.
base <- c(1, 2, 3)
l <- list(field = matrix(1))
field <- matrix(2)
defUse <- 3
g <- function(x) { x + y }
f <- function(x) { lapply(x, g) + 1 }
f <- function(x) {
defUse <- base::as.integer(x) + 1 # Test for access operators `::`.
lapply(x, g) + 1 # Test for capturing function call "g"'s closure as a argument of lapply.
l$field[1,1] <- 3 # Test for access operators `$`.
res <- defUse + l$field[1,] # Test for def-use chain of "defUse", and "" symbol.
f(res) # Test for recursive calls.
}
newF <- cleanClosure(f)
env <- environment(newF)
expect_equal(length(ls(env)), 1) # Only "g". "y" should be in the environemnt of g.
expect_equal(ls(env), "g")
expect_equal(length(ls(env)), 3) # Only "g", "l" and "f". No "base", "field" or "defUse".
expect_true("g" %in% ls(env))
expect_true("l" %in% ls(env))
expect_true("f" %in% ls(env))
expect_equal(get("l", envir = env, inherits = FALSE), l)
# "y" should be in the environment of g.
newG <- get("g", envir = env, inherits = FALSE)
env <- environment(newG)
expect_equal(length(ls(env)), 1)
Expand All @@ -83,20 +96,6 @@ test_that("cleanClosure on R functions", {
env <- environment(newF)
expect_equal(length(ls(env)), 0) # "y" and "g" should not be included.

# Test for access operators `$`, `::` and `:::`.
l <- list(a = 1)
a <- 2
base <- c(1, 2, 3)
f <- function(x) {
z <- base::as.integer(x) + 1
l$a <- 3
z + l$a
}
newF <- cleanClosure(f)
env <- environment(newF)
expect_equal(ls(env), "l") # "base" and "a" should not be included.
expect_equal(get("l", envir = env, inherits = FALSE), l)

# Test for overriding variables in base namespace (Issue: SparkR-196).
nums <- as.list(1:10)
rdd <- parallelize(sc, nums, 2L)
Expand Down

0 comments on commit 3214c6d

Please sign in to comment.