Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions R/groupingsets.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,48 @@ rollup.data.table = function(x, j, by, .SDcols, id = FALSE, label = NULL, ...) {
groupingsets.data.table(x, by=by, sets=sets, .SDcols=.SDcols, id=id, jj=jj, label=label, enclos = parent.frame())
}

# Helper function to process SDcols
.processSDcols = function(SDcols_sub, SDcols_missing, x, jsub, by, enclos = parent.frame()) {
names_x = names(x)
bysub = substitute(by)
allbyvars = intersect(all.vars(bysub), names_x)
usesSD = any(all.vars(jsub) == ".SD")
if (!usesSD) {
return(NULL)
}
if (SDcols_missing) {
ansvars = sdvars = setdiff(unique(names_x), union(by, allbyvars))
ansvals = match(ansvars, names_x)
return(list(ansvars = ansvars, sdvars = sdvars, ansvals = ansvals))
}
sub.result = SDcols_sub
if (is.call(sub.result) && as.character(sub.result[[1L]]) == "patterns") {
.SDcols = eval_with_cols(sub.result, names_x)
} else {
.SDcols = eval(sub.result, enclos)
}
if (is.character(.SDcols)) {
idx = .SDcols %chin% names_x
if (!all(idx))
stopf("Some items of .SDcols are not column names: %s", toString(.SDcols[!idx]))
ansvars = sdvars = .SDcols
ansvals = match(ansvars, names_x)
} else if (is.numeric(.SDcols)) {
ansvals = as.integer(.SDcols)
if (any(ansvals < 1L | ansvals > length(names_x)))
stopf(".SDcols contains indices out of bounds")
ansvars = sdvars = names_x[ansvals]
} else if (is.logical(.SDcols)) {
if (length(.SDcols) != length(names_x))
stopf(".SDcols is a logical vector of length %d but there are %d columns", length(.SDcols), length(names_x))
ansvals = which(.SDcols)
ansvars = sdvars = names_x[ansvals]
} else {
stopf(".SDcols must be character, numeric, or logical")
}
list(ansvars = ansvars, sdvars = sdvars, ansvals = ansvals)
}

cube = function(x, ...) {
UseMethod("cube")
}
Expand All @@ -29,6 +71,17 @@ cube.data.table = function(x, j, by, .SDcols, id = FALSE, label = NULL, ...) {
stopf("Argument 'id' must be a logical scalar.")
if (missing(j))
stopf("Argument 'j' is required")
# Implementing NSE in cube using the helper, .processSDcols
jj = substitute(j)
sdcols_result = .processSDcols(SDcols_sub = substitute(.SDcols), SDcols_missing = missing(.SDcols), x = x, jsub = jj, by = by, enclos = parent.frame())
if (is.null(sdcols_result)) {
.SDcols = NULL
} else {
ansvars = sdcols_result$ansvars
sdvars = sdcols_result$sdvars
ansvals = sdcols_result$ansvals
.SDcols = sdvars
}
# generate grouping sets for cube - power set: http://stackoverflow.com/a/32187892/2490497
n = length(by)
keepBool = sapply(2L^(seq_len(n)-1L), function(k) rep(c(FALSE, TRUE), times=k, each=((2L^n)/(2L*k))))
Expand Down
35 changes: 35 additions & 0 deletions inst/tests/tests.Rraw
Original file line number Diff line number Diff line change
Expand Up @@ -11503,6 +11503,41 @@ test(1750.34,
character(0)),
id = TRUE)
)
test(1750.35,
cube(dt, j = lapply(.SD, sum), by = c("color","year","status"), id=TRUE, .SDcols=patterns("value")),
groupingsets(dt, j = lapply(.SD, sum), by = c("color","year","status"), .SDcols = "value",
sets = list(c("color","year","status"),
c("color","year"),
c("color","status"),
"color",
c("year","status"),
"year",
"status",
character(0)),
id = TRUE)
)
test(1750.36,
cube(dt, j = lapply(.SD, sum), by = "year", .SDcols = c("value", "BADCOL")),
error = "Some items of \\.SDcols are not column names"
)
test(1750.37,
cube(dt, j = lapply(.SD, sum), by = "year", .SDcols = c(TRUE, FALSE)),
error = "\\.SDcols is a logical vector of length"
)
test(1750.38,
cube(dt, j = lapply(.SD, mean), by = "color", .SDcols = c(FALSE, FALSE, FALSE, TRUE, FALSE), id=TRUE),
groupingsets(dt, j = lapply(.SD, mean), by = "color", .SDcols = "amount",
sets = list("color", character(0)),
id = TRUE)
)
test(1750.39,
cube(dt, j = lapply(.SD, sum), by = "color", .SDcols = list("amount")),
error = ".SDcols must be character, numeric, or logical"
)
test(1750.40,
cube(dt, j = lapply(.SD, sum), by = "color", .SDcols = c(1, 99)),
error = "out of bounds"
)
# grouping sets with integer64
if (test_bit64) {
set.seed(26)
Expand Down
Loading