From d9da4519fa9efa5db769a43cb75d296a94d44a74 Mon Sep 17 00:00:00 2001
From: Sun Rui
Date: Wed, 4 Feb 2015 21:46:49 +0800
Subject: [PATCH] [SPARKR-150] phase 1: implement sortBy() and sortByKey().

---
 pkg/NAMESPACE             |   2 +
 pkg/R/RDD.R               | 106 ++++++++++++++++++++++++++++++++++++++
 pkg/R/utils.R             |   4 +-
 pkg/inst/tests/test_rdd.R |  14 +++++
 pkg/man/sortBy.Rd         |  36 +++++++++++++
 pkg/man/sortByKey.Rd      |  34 ++++++++++++
 6 files changed, 194 insertions(+), 2 deletions(-)
 create mode 100644 pkg/man/sortBy.Rd
 create mode 100644 pkg/man/sortByKey.Rd

diff --git a/pkg/NAMESPACE b/pkg/NAMESPACE
index 6977dd2d43330..ac90c36734e6d 100644
--- a/pkg/NAMESPACE
+++ b/pkg/NAMESPACE
@@ -44,6 +44,8 @@ exportMethods(
               "sampleRDD",
               "saveAsTextFile",
               "saveAsObjectFile",
+              "sortBy",
+              "sortByKey",
               "take",
               "takeSample",
               "unionRDD",
diff --git a/pkg/R/RDD.R b/pkg/R/RDD.R
index 66e616f252533..b788f13119cb6 100644
--- a/pkg/R/RDD.R
+++ b/pkg/R/RDD.R
@@ -1240,6 +1240,40 @@ setMethod("flatMapValues",
             flatMap(X, flatMapFunc)
           })
 
+#' Sort an RDD by the given key function.
+#'
+#' @param rdd An RDD to be sorted.
+#' @param func A function used to compute the sort key for each element.
+#' @param ascending A flag to indicate whether the sorting is ascending or descending.
+#' @param numPartitions Number of partitions to create.
+#' @return An RDD where all elements are sorted.
+#' @rdname sortBy
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd <- parallelize(sc, list(3, 2, 1))
+#' collect(sortBy(rdd, function(x) { x })) # list (1, 2, 3)
+#'}
+setGeneric("sortBy", function(rdd, func, ascending, numPartitions) { standardGeneric("sortBy") })
+
+setClassUnion("missingOrLogical", c("missing", "logical"))
+#' @rdname sortBy
+#' @aliases sortBy,RDD,RDD-method
+setMethod("sortBy",
+          signature(rdd = "RDD", func = "function",
+                    ascending = "missingOrLogical", numPartitions = "missingOrInteger"),
+          function(rdd, func, ascending, numPartitions) {
+            if (missing(ascending)) {
+              ascending <- TRUE
+            }
+            if (missing(numPartitions)) {
+              numPartitions <- SparkR::numPartitions(rdd)
+            }
+            # Pair each element with its sort key, sort the pairs by key, then keep the values.
+            values(sortByKey(keyBy(rdd, func), ascending, numPartitions))
+          })
+
 ############ Shuffle Functions ############
 
 #' Partition an RDD by key
@@ -1796,6 +1830,78 @@ setMethod("cogroup",
                                      group.func)
           })
 
+#' Sort an (k, v) pair RDD by k.
+#'
+#' @param rdd An (k, v) pair RDD to be sorted.
+#' @param ascending A flag to indicate whether the sorting is ascending or descending.
+#' @param numPartitions Number of partitions to create.
+#' @return An RDD where all (k, v) pair elements are sorted.
+#' @rdname sortByKey
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd <- parallelize(sc, list(list(3, 3), list(2, 2), list(1, 1)))
+#' collect(sortByKey(rdd)) # list (list(1, 1), list(2, 2), list(3, 3))
+#'}
+setGeneric("sortByKey", function(rdd, ascending, numPartitions) { standardGeneric("sortByKey") })
+
+#' @rdname sortByKey
+#' @aliases sortByKey,RDD,RDD-method
+setMethod("sortByKey",
+          signature(rdd = "RDD", ascending = "missingOrLogical", numPartitions = "missingOrInteger"),
+          function(rdd, ascending, numPartitions) {
+            if (missing(ascending)) {
+              ascending <- TRUE
+            }
+            if (missing(numPartitions)) {
+              numPartitions <- SparkR::numPartitions(rdd)
+            }
+
+            rangeBounds <- list()
+
+            if (numPartitions > 1) {
+              rddSize <- count(rdd)
+              # constant from Spark's RangePartitioner
+              maxSampleSize <- numPartitions * 20
+              fraction <- min(maxSampleSize / max(rddSize, 1), 1.0)
+
+              samples <- collect(keys(sampleRDD(rdd, FALSE, fraction, 1L)))
+              # Keep the bounds ascending even for a descending sort: rangePartitionFunc
+              # scans them upward and flips the index. sort() only works on atomic vectors.
+              samples <- sort(unlist(samples, recursive = FALSE))
+
+              if (length(samples) > 0) {
+                rangeBounds <- lapply(seq_len(numPartitions - 1),
+                                      function(i) {
+                                        j <- ceiling(length(samples) * i / numPartitions)
+                                        samples[j]
+                                      })
+              }
+            }
+            # Map a key to a 0-based partition index by scanning the ascending bounds.
+            rangePartitionFunc <- function(key) {
+              partition <- 0
+
+              while (partition < length(rangeBounds) && key > rangeBounds[[partition + 1]]) {
+                partition <- partition + 1
+              }
+
+              if (ascending) {
+                partition
+              } else {
+                numPartitions - partition - 1
+              }
+            }
+            # Sort the pairs of one partition by key, in the requested direction.
+            partitionFunc <- function(part) {
+              sortKeyValueList(part, decreasing = !ascending)
+            }
+
+            newRDD <- partitionBy(rdd, numPartitions, rangePartitionFunc)
+            lapplyPartition(newRDD, partitionFunc)
+          })
+
 # TODO: Consider caching the name in the RDD's environment
 #' Return an RDD's name.
#' diff --git a/pkg/R/utils.R b/pkg/R/utils.R index 7c2a153b8b55d..8bb77463602b4 100644 --- a/pkg/R/utils.R +++ b/pkg/R/utils.R @@ -197,9 +197,9 @@ initAccumulator <- function() { # Utility function to sort a list of key value pairs # Used in unit tests -sortKeyValueList <- function(kv_list) { +sortKeyValueList <- function(kv_list, decreasing = FALSE) { keys <- sapply(kv_list, function(x) x[[1]]) - kv_list[order(keys)] + kv_list[order(keys, decreasing = decreasing)] } # Utility function to generate compact R lists from grouped rdd diff --git a/pkg/inst/tests/test_rdd.R b/pkg/inst/tests/test_rdd.R index 7c2599f51e7e5..5f675fcd689ba 100644 --- a/pkg/inst/tests/test_rdd.R +++ b/pkg/inst/tests/test_rdd.R @@ -254,6 +254,12 @@ test_that("keyBy on RDDs", { expect_equal(actual, lapply(nums, function(x) { list(func(x), x) })) }) +test_that("sortBy() on RDDs", { + sortedRdd <- sortBy(rdd, function(x) { x }, ascending = FALSE) + actual <- collect(sortedRdd) + expect_equal(actual, as.list(sort(nums, decreasing = TRUE))) +}) + test_that("keys() on RDDs", { keys <- keys(intRdd) actual <- collect(keys) @@ -373,3 +379,11 @@ test_that("fullOuterJoin() on pairwise RDDs", { expect_equal(sortKeyValueList(actual), sortKeyValueList(list(list("a", list(1, NULL)), list("b", list(2, NULL)), list("d", list(NULL, 4)), list("c", list(NULL, 3))))) }) + +test_that("sortByKey() on pairwise RDDs", { + numPairsRdd <- map(rdd, function(x) { list (x, x) }) + sortedRdd <- sortByKey(numPairsRdd, ascending = FALSE) + actual <- collect(sortedRdd) + numPairs <- lapply(nums, function(x) { list (x, x) }) + expect_equal(actual, sortKeyValueList(numPairs, decreasing = TRUE)) +}) \ No newline at end of file diff --git a/pkg/man/sortBy.Rd b/pkg/man/sortBy.Rd new file mode 100644 index 0000000000000..d3a231c745240 --- /dev/null +++ b/pkg/man/sortBy.Rd @@ -0,0 +1,36 @@ +% Generated by roxygen2 (4.0.2): do not edit by hand +\docType{methods} +\name{sortBy} +\alias{sortBy} +\alias{sortBy,RDD,RDD-method} 
+\alias{sortBy,RDD,function,missingOrLogical,missingOrInteger-method}
+\title{Sort an RDD by the given key function.}
+\usage{
+sortBy(rdd, func, ascending, numPartitions)
+
+\S4method{sortBy}{RDD,`function`,missingOrLogical,missingOrInteger}(rdd, func,
+  ascending, numPartitions)
+}
+\arguments{
+\item{rdd}{An RDD to be sorted.}
+
+\item{func}{A function used to compute the sort key for each element.}
+
+\item{ascending}{A flag to indicate whether the sorting is ascending or descending.}
+
+\item{numPartitions}{Number of partitions to create.}
+}
+\value{
+An RDD where all elements are sorted.
+}
+\description{
+Sort an RDD by the given key function.
+}
+\examples{
+\dontrun{
+sc <- sparkR.init()
+rdd <- parallelize(sc, list(3, 2, 1))
+collect(sortBy(rdd, function(x) { x })) # list (1, 2, 3)
+}
+}
+
diff --git a/pkg/man/sortByKey.Rd b/pkg/man/sortByKey.Rd
new file mode 100644
index 0000000000000..b58dd5bf22ce6
--- /dev/null
+++ b/pkg/man/sortByKey.Rd
@@ -0,0 +1,34 @@
+% Generated by roxygen2 (4.0.2): do not edit by hand
+\docType{methods}
+\name{sortByKey}
+\alias{sortByKey}
+\alias{sortByKey,RDD,RDD-method}
+\alias{sortByKey,RDD,missingOrLogical,missingOrInteger-method}
+\title{Sort an (k, v) pair RDD by k.}
+\usage{
+sortByKey(rdd, ascending, numPartitions)
+
+\S4method{sortByKey}{RDD,missingOrLogical,missingOrInteger}(rdd, ascending,
+  numPartitions)
+}
+\arguments{
+\item{rdd}{An (k, v) pair RDD to be sorted.}
+
+\item{ascending}{A flag to indicate whether the sorting is ascending or descending.}
+
+\item{numPartitions}{Number of partitions to create.}
+}
+\value{
+An RDD where all (k, v) pair elements are sorted.
+}
+\description{
+Sort an (k, v) pair RDD by k.
+}
+\examples{
+\dontrun{
+sc <- sparkR.init()
+rdd <- parallelize(sc, list(list(3, 3), list(2, 2), list(1, 1)))
+collect(sortByKey(rdd)) # list (list(1, 1), list(2, 2), list(3, 3))
+}
+}
+