Skip to content

Commit

Permalink
Added some unit tests and fixed minor issues.
Browse files Browse the repository at this point in the history
  • Loading branch information
jwood000 committed Sep 28, 2023
1 parent faf28fc commit e49a928
Show file tree
Hide file tree
Showing 9 changed files with 264 additions and 18 deletions.
5 changes: 4 additions & 1 deletion inst/include/ComboGroups/ComboGroupsClass.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,11 @@ class ComboGroupsClass : public Combo {
cpp11::writable::list dimNames;
cpp11::writable::strings myNames;

std::string grpSizeDesc;

bool IsArray;
int r; // Number of groups
int rDisp; // This will differ in the General case when OneGrp = true
int r; // Number of groups
const std::unique_ptr<ComboGroupsTemplate> CmbGrp;

SEXP SingleReturn();
Expand Down
1 change: 0 additions & 1 deletion inst/include/ComboGroups/ComboGroupsGeneral.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ class ComboGroupsGeneral : public ComboGroupsTemplate {
private:

const GroupHelper MyGrp;
const bool OneGrp;

public:

Expand Down
4 changes: 4 additions & 0 deletions inst/include/ComboGroups/ComboGroupsTemplate.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ class ComboGroupsTemplate {

std::string GroupType;

bool OneGrp; // Used only for General case, but we need to be able to
// to access it, so that we can properly name the output
// when using iterables (e.g. nextIter method)
const int n; // Size of vector which is also the size of z (i.e. z.size())
const int r; // Number of groups

Expand Down Expand Up @@ -68,6 +71,7 @@ class ComboGroupsTemplate {
bool GetIsGmp() const {return IsGmp;}
int GetNumGrps() const {return r;}
std::string GetType() const {return GroupType;}
bool GetOneGrp() const {return OneGrp;}

SEXP GetCount() const {
return CppConvert::GetCount(IsGmp, computedRowsMpz, computedRows);
Expand Down
2 changes: 1 addition & 1 deletion inst/include/ComboGroups/GetComboGroups.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ SEXP GetComboGroups(
SEXP Rv, nextGrpFunc nextCmbGrp, nthFuncDbl nthCmbGrp,
nthFuncGmp nthCmbGrpGmp, finalTouchFunc FinalTouch,
const std::vector<double> &vNum, const std::vector<int> &vInt,
std::vector<int> &startZ, const VecType &myType,
std::vector<int> startZ, const VecType &myType,
const std::vector<double> &mySample, const std::vector<mpz_class> &myVec,
mpz_class lowerMpz, double lower, int n, int numResults, int nThreads,
bool IsArray, bool IsNamed, bool Parallel, bool IsSample, bool IsGmp
Expand Down
51 changes: 39 additions & 12 deletions src/ComboGroupsClass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ SEXP ComboGroupsClass::GeneralReturn(int numResults) {

SetThreads(LocalPar, maxThreads, numResults,
myType, nThreads, sexpNThreads, limit);

CmbGrpClsFuncs f = GetClassFuncs(CmbGrp);

cpp11::sexp res = GetComboGroups(
Expand All @@ -33,6 +32,7 @@ SEXP ComboGroupsClass::GeneralReturn(int numResults) {
numResults, nThreads, IsArray, false, LocalPar, false, IsGmp
);

zUpdateIndex(vNum, vInt, z, sexpVec, res, m, numResults);
return res;
}

Expand All @@ -46,6 +46,7 @@ ComboGroupsClass::ComboGroupsClass(
RmaxThreads, RnumThreads, Rparallel),
CmbGrp(GroupPrep(Rv, RNumGroups, RGrpSize, n)) {

prevIterAvailable = false;
CmbGrp->SetCount();
IsGmp = CmbGrp->GetIsGmp();
computedRows = CmbGrp->GetDblCount();
Expand All @@ -68,6 +69,7 @@ ComboGroupsClass::ComboGroupsClass(

IsArray = (retType == "3Darray");
r = CmbGrp->GetNumGrps();
rDisp = r;
const int grpSize = n / r;

std::vector<std::string> myColNames(r, "Grp");
Expand All @@ -76,6 +78,10 @@ ComboGroupsClass::ComboGroupsClass(
myColNames[j] += std::to_string(j + 1);
}

for (auto g: CmbGrp->GetGroupSizes()) {
grpSizeDesc += (std::to_string(g) + ", ");
}

if (IsArray) {
myNames.resize(r);

Expand All @@ -96,6 +102,33 @@ ComboGroupsClass::ComboGroupsClass(
myNames[k] = myColNames[i].c_str();
}
}
} else if (CmbGrp->GetOneGrp()) {
myNames.resize(n);
std::vector<int> vGrpSizes(CmbGrp->GetGroupSizes());

const int numOneGrps = vGrpSizes.front();
std::vector<int> realGrps(vGrpSizes);
realGrps.erase(realGrps.begin());
realGrps.insert(realGrps.begin(), numOneGrps, 1);

rDisp = realGrps.size();
std::vector<std::string> myColNamesOne(rDisp, "Grp");

for (int j = 0; j < rDisp; ++j) {
myColNamesOne[j] += std::to_string(j + 1);
}

for (int i = 0, k = 0; i < rDisp; ++i) {
for (int j = 0; j < realGrps[i]; ++j, ++k) {
myNames[k] = myColNamesOne[i].c_str();
}
}

grpSizeDesc.clear();

for (auto g: realGrps) {
grpSizeDesc += (std::to_string(g) + ", ");
}
} else {
myNames.resize(n);
std::vector<int> vGrpSizes(CmbGrp->GetGroupSizes());
Expand All @@ -106,6 +139,10 @@ ComboGroupsClass::ComboGroupsClass(
}
}
}

// Remove the last space and comma
grpSizeDesc.pop_back();
grpSizeDesc.pop_back();
}

void ComboGroupsClass::startOver() {
Expand Down Expand Up @@ -308,19 +345,9 @@ SEXP ComboGroupsClass::back() {

SEXP ComboGroupsClass::summary() {

std::string grpSizeDesc;

for (auto g: CmbGrp->GetGroupSizes()) {
grpSizeDesc += (std::to_string(g) + ", ");
}

// Remove the last space and comma
grpSizeDesc.pop_back();
grpSizeDesc.pop_back();

const std::string gtype = CmbGrp->GetType();
const std::string prefix = "Partition of v of length " +
std::to_string(n) + " into " + std::to_string(r);
std::to_string(n) + " into " + std::to_string(rDisp);
const std::string suffix = (gtype == "Uniform") ? " uniform groups" :
" groups of sizes: " + grpSizeDesc;

Expand Down
3 changes: 2 additions & 1 deletion src/ComboGroupsGeneral.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,9 @@ ComboGroupsGeneral::ComboGroupsGeneral(
int n_, int numGroups, int i1, int i2,
int bnd, GroupHelper MyGrp_, bool OneGrp_
) : ComboGroupsTemplate(n_, numGroups, i1, i2, bnd),
MyGrp(MyGrp_), OneGrp(OneGrp_) {
MyGrp(MyGrp_) {

OneGrp = OneGrp_;
GroupType = "General";
}

Expand Down
5 changes: 4 additions & 1 deletion src/ComboGroupsTemplate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@

ComboGroupsTemplate::ComboGroupsTemplate(
int n_, int numGroups, int i1, int i2, int bnd
) : n(n_), r(numGroups), idx1(i1), idx2(i2), curr_bnd(bnd) {}
) : n(n_), r(numGroups), idx1(i1), idx2(i2), curr_bnd(bnd) {

OneGrp = false;
}

void ComboGroupsTemplate::SetCount() {
computedRows = numGroupCombs();
Expand Down
2 changes: 1 addition & 1 deletion src/GetComboGroups.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ SEXP GetComboGroups(
SEXP Rv, nextGrpFunc nextCmbGrp, nthFuncDbl nthCmbGrp,
nthFuncGmp nthCmbGrpGmp, finalTouchFunc FinalTouch,
const std::vector<double> &vNum, const std::vector<int> &vInt,
std::vector<int> &startZ, const VecType &myType,
std::vector<int> startZ, const VecType &myType,
const std::vector<double> &mySample,
const std::vector<mpz_class> &myBigSamp,
mpz_class lowerMpz, double lower, int n, int numResults, int nThreads,
Expand Down
209 changes: 209 additions & 0 deletions tests/testthat/testComboGroupsClass.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
test_that("comboGroupsIter produces correct results", {

comboGroupsClassTest <- function(
v_pass, n_grps = NULL, grp_sizes = NULL,
ret = "matrix", testRand = TRUE
) {

myResults <- vector(mode = "logical")

myRows <- comboGroupsCount(v_pass, n_grps, grp_sizes)
a <- comboGroupsIter(v_pass, n_grps, grp_sizes, ret)
b <- comboGroups(v_pass, n_grps, grp_sizes, ret)

myResults <- c(myResults, isTRUE(all.equal(
a@summary()$totalResults, myRows)
))

if (length(v_pass) == 1 && v_pass == 0) {
myResults <- c(myResults, v_pass == a@sourceVector())
} else if (length(v_pass) == 1) {
myResults <- c(myResults, isTRUE(
all.equal(abs(v_pass), length(a@sourceVector()))
))
} else {
myResults <- c(myResults, isTRUE(
all.equal(sort(v_pass), a@sourceVector())
))
}

if (testRand) {
myResults <- c(myResults, isTRUE(
all.equal(a@front(), b[1 ,])
))
myResults <- c(myResults, isTRUE(all.equal(a@currIter(),
b[1 ,])))
myResults <- c(myResults, isTRUE(all.equal(a@back(),
b[myRows, ])))
myResults <- c(myResults, isTRUE(all.equal(a@currIter(),
b[myRows, ])))
}

a@startOver()
msg <- capture.output(noMore <- a@currIter())
myResults <- c(myResults, is.null(noMore))
myResults <- c(myResults, grepl("Iterator Initialized. To see the first", msg[1]))
a1 <- b

if (myRows) {
for (i in 1:myRows) {
a1[i, ] <- a@nextIter()
}

myResults <- c(myResults, isTRUE(all.equal(a1, b)))
a@startOver()
num_iters <- if (myRows > 10) 3L else 1L
numTest <- as.integer(myRows / num_iters);

s <- 1L
e <- numTest

for (i in 1:num_iters) {
myResults <- c(myResults, isTRUE(all.equal(a@nextNIter(numTest),
b[s:e, , drop = FALSE])))
s <- e + 1L
e <- e + numTest
}

a@startOver()
myResults <- c(myResults, isTRUE(all.equal(a@nextRemaining(), b)))
msg <- capture.output(noMore <- a@nextIter())
myResults <- c(myResults, is.null(noMore))

if (testRand) {
a@back()
msg <- capture.output(noMore <- a@nextNIter(1))
myResults <- c(myResults, is.null(noMore))
myResults <- c(myResults, "No more results." == msg[1])
msg <- capture.output(noMore <- a@currIter())
myResults <- c(myResults, "No more results." == msg[1])

a@startOver()
a@back()
msg <- capture.output(noMore <- a@nextRemaining())
myResults <- c(myResults, is.null(noMore))
myResults <- c(myResults, "No more results." == msg[1])

a@startOver()
a@back()
msg <- capture.output(noMore <- a@nextIter())
myResults <- c(myResults, is.null(noMore))
myResults <- c(myResults, "No more results." == msg[1])

samp <- sample(myRows, numTest)
myResults <- c(myResults, isTRUE(all.equal(a[[samp]], b[samp, ])))
one_samp <- sample(myRows, 1)
myResults <- c(myResults, isTRUE(all.equal(a[[one_samp]], b[one_samp, ])))
}
} else {
a@startOver()
msg <- capture.output(noMore <- a@nextNIter(1))
myResults <- c(myResults, is.null(noMore))
myResults <- c(myResults, "No more results." == msg[1])
msg <- capture.output(noMore <- a@currIter())
myResults <- c(myResults, "No more results." == msg[1])

a@startOver()
msg <- capture.output(noMore <- a@nextIter())
myResults <- c(myResults, is.null(noMore))
myResults <- c(myResults, "No more results." == msg[1])

a@startOver()
msg <- capture.output(noMore <- a@nextRemaining())
myResults <- c(myResults, is.null(noMore))
myResults <- c(myResults, "No more results." == msg[1])
}

rm(a, a1, b)
gc()
all(myResults)
}

expect_true(comboGroupsClassTest(12, 3))
# expect_true(comboGroupsClassTest(12, 3, ret = "3Darray"))
expect_true(comboGroupsClassTest(12, grp_sizes = c(3, 3, 6)))
expect_true(comboGroupsClassTest(12, grp_sizes = c(3, 4, 5)))
expect_true(comboGroupsClassTest(11, grp_sizes = c(1, 2, 2, 3, 3)))
expect_true(comboGroupsClassTest(4, grp_sizes = c(1, 1, 1, 1)))
expect_true(comboGroupsClassTest(4, grp_sizes = c(1, 1, 2)))
expect_true(comboGroupsClassTest(3, grp_sizes = c(1, 2)))
expect_true(comboGroupsClassTest(3, 1))
expect_true(comboGroupsClassTest(1, 1))

##******** BIG TESTS *********##
comboGroupsClassBigZTest <- function(
v_pass, n_grps = NULL, grp_sizes = NULL,
ret = "matrix", lenCheck = 1000, testRand = TRUE
) {

myResults <- vector(mode = "logical")

myRows <- comboGroupsCount(v_pass, n_grps, grp_sizes)
a <- comboGroupsIter(v_pass, n_grps, grp_sizes, ret)
b1 <- comboGroups(v_pass, n_grps, grp_sizes, ret, upper = lenCheck)
b2 <- comboGroups(v_pass, n_grps, grp_sizes, ret,
lower = gmp::sub.bigz(myRows, lenCheck - 1))

myResults <- c(myResults, isTRUE(all.equal(
a@summary()$totalResults, myRows)
))

if (length(v_pass) == 1) {
myResults <- c(myResults, isTRUE(
all.equal(v_pass, length(a@sourceVector()))
))
} else {
myResults <- c(myResults, isTRUE(
all.equal(sort(v_pass), a@sourceVector())
))
}

myResults <- c(myResults, isTRUE(
all.equal(a@front(), b1[1 ,])
))
myResults <- c(myResults, isTRUE(all.equal(a@currIter(),
b1[1 ,])))
myResults <- c(myResults, isTRUE(all.equal(a@back(),
b2[lenCheck, ])))
myResults <- c(myResults, isTRUE(all.equal(a@currIter(),
b2[lenCheck, ])))

a@startOver()
a1 <- b1

for (i in 1:lenCheck) {
a1[i, ] <- a@nextIter()
}

myResults <- c(myResults, isTRUE(all.equal(a1, b1)))
a@startOver()
numTest <- as.integer(lenCheck / 3);
s <- 1L
e <- numTest

for (i in 1:3) {
myResults <- c(myResults, isTRUE(all.equal(a@nextNIter(numTest),
b1[s:e, ])))
s <- e + 1L
e <- e + numTest
}

a@startOver()
a[[gmp::sub.bigz(myRows, lenCheck)]]
myResults <- c(myResults, isTRUE(all.equal(a@nextRemaining(), b2)))

t <- capture.output(a@nextIter())
myResults <- c(myResults, is.null(a@nextIter()))
myResults <- c(myResults, is.null(a@nextNIter(1)))
myResults <- c(myResults, is.null(a@nextRemaining()))

samp1 <- sample(lenCheck, 5)
samp2 <- gmp::sub.bigz(myRows, lenCheck) + gmp::as.bigz(samp1)
myResults <- c(myResults, isTRUE(all.equal(a[[samp1]], b1[samp1, ])))
myResults <- c(myResults, isTRUE(all.equal(a[[samp2]], b2[samp1, ])))
rm(a, a1, b1, b2)
gc()
all(myResults)
}

})

0 comments on commit e49a928

Please sign in to comment.