Skip to content

Commit

Permalink
Fix groupby when applied to a view (#1545)
Browse files Browse the repository at this point in the history
Added a new internal method `DataTable::group(spec, as_view) -> pair<RowIndex, Groupby>`, as a replacement for `DataTable::sortby()`. The new method has a more predictable return value: the returned rowindex applies either to the column itself or to its source (for a view column), depending on the value of the parameter `as_view`.

Closes #1542
  • Loading branch information
st-pasha authored Jan 14, 2019
1 parent d7e026d commit 0e20f4e
Show file tree
Hide file tree
Showing 5 changed files with 114 additions and 33 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
- Function `count()` now returns correct result within the `DT[i, j]` expression
with non-trivial `i` (#1316).

- Fixed groupby when it is applied to a Frame with view columns (#1542).


### Changed

Expand Down
32 changes: 23 additions & 9 deletions c/datatable.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,20 @@ class Stats;
class DataTable;
class NameProvider;

/**
 * Per-column entry of a sort/group request.
 *
 * `col_index` is the position of the column within the DataTable's column
 * list. The three flags all start out as `false`; judging by their names
 * they select descending order, NA placement and "sort without grouping",
 * but no code visible here reads them yet -- TODO confirm against the sorter.
 */
struct sort_spec {
  size_t col_index;
  bool descending = false;
  bool na_last = false;
  bool sort_only = false;
  // Unnamed bit-field: reserves 40 bits of padding and holds no data.
  size_t : 40;

  // Intentionally non-explicit, as before: a bare column index converts
  // to a sort_spec with every flag switched off.
  sort_spec(size_t i) : col_index(i) {}
};

// Pointer to a const member function of Column that takes no arguments and
// returns a Column*.
typedef Column* (Column::*colmakerfn)(void) const;
// Shorthand names for the container / smart-pointer types used throughout
// the DataTable interface.
using colvec = std::vector<Column*>;      // vector of raw Column pointers
using intvec = std::vector<size_t>;       // vector of indices (size_t)
using strvec = std::vector<std::string>;  // vector of strings
using dtptr = std::unique_ptr<DataTable>; // uniquely-owned DataTable

Expand Down Expand Up @@ -79,29 +91,31 @@ class DataTable {
DataTable(colvec&& cols, const DataTable*);
~DataTable();

void delete_columns(std::vector<size_t>&);
void delete_columns(intvec&);
void delete_all();
void resize_rows(size_t n);
void replace_rowindex(const RowIndex& newri);
void apply_rowindex(const RowIndex&);
void replace_groupby(const Groupby& newgb);
void reify();
void rbind(std::vector<DataTable*>, std::vector<std::vector<size_t>>);
void rbind(std::vector<DataTable*>, std::vector<intvec>);
DataTable* cbind(std::vector<DataTable*>);
DataTable* copy() const;
size_t memory_footprint() const;

/**
* Sort the DataTable by specified columns, and return the corresponding
* RowIndex. The array `colindices` provides the indices of columns to
* sort on. If an index is negative, it indicates that the column must be
* sorted in descending order instead of default ascending.
* sort on.
*
* If `out_grps` is not null, then in addition to sorting, the grouping
* information will be computed and stored in `*out_grps`.
*/
RowIndex sortby(const std::vector<size_t>& colindices,
Groupby* out_grps) const;
// TODO: remove
RowIndex sortby(const intvec& colindices, Groupby* out_grps) const;

std::pair<RowIndex, Groupby>
group(const std::vector<sort_spec>& spec, bool as_view = false) const;

// Names
const strvec& get_names() const;
Expand All @@ -111,13 +125,13 @@ class DataTable {
void copy_names_from(const DataTable* other);
void set_names_to_default();
void set_names(const py::olist& names_list);
void set_names(const std::vector<std::string>& names_list);
void set_names(const strvec& names_list);
void replace_names(py::odict replacements);
void reorder_names(const std::vector<size_t>& col_indices);
void reorder_names(const intvec& col_indices);

// Key
size_t get_nkeys() const;
void set_key(std::vector<size_t>& col_indices);
void set_key(intvec& col_indices);
void clear_key();
void set_nkeys_unsafe(size_t K);

Expand Down
9 changes: 7 additions & 2 deletions c/expr/by_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,13 @@ void collist_bn::execute(workframe& wf) {
if (ri0) {
throw NotImplError();
}
RowIndex ri = dt0->sortby(indices, &gb);
wf.apply_rowindex(ri);
std::vector<sort_spec> spec;
for (size_t i : indices) {
spec.push_back(sort_spec(i));
}
auto res = dt0->group(spec);
gb = std::move(res.second);
wf.apply_rowindex(res.first);
}


Expand Down
85 changes: 63 additions & 22 deletions c/sort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -428,17 +428,18 @@ class SortContext {
}


RowIndex get_result(Groupby* out_grps) {
RowIndex res = RowIndex(
arr32_t(n, static_cast<int32_t*>(container_o.release()), true)
);
if (out_grps) {
size_t ngrps = gg.size();
xassert(groups.size() > ngrps);
groups.resize(ngrps + 1);
*out_grps = Groupby(ngrps, groups.to_memoryrange());
}
return res;
/**
 * Build the resulting RowIndex from the ordering buffer.
 *
 * The buffer is released from `container_o`, viewed as an array of `n`
 * int32_t values, and wrapped into an arr32_t that is handed to RowIndex.
 * After this call `container_o` no longer holds the buffer.
 */
RowIndex get_result_rowindex() {
  // NOTE(review): the trailing `true` presumably tells arr32_t to take
  // ownership of the released pointer -- confirm against arr32_t.
  return RowIndex(
      arr32_t(n, static_cast<int32_t*>(container_o.release()), true));
}


std::pair<RowIndex, Groupby> get_result_groups() {
size_t ng = gg.size();
xassert(groups.size() > ng);
groups.resize(ng + 1);
return std::pair<RowIndex, Groupby>(get_result_rowindex(),
Groupby(ng, groups.to_memoryrange()));
}


Expand Down Expand Up @@ -1173,26 +1174,54 @@ class SortContext {
//==============================================================================
// Main sorting routines
//==============================================================================
using RiGb = std::pair<RowIndex, Groupby>;


/**
 * Group the Frame by the columns listed in `spec` and return the pair
 * (RowIndex, Groupby): the ordering of the rows together with the group
 * boundaries within that ordering.
 *
 * @param spec
 *     The columns to group by, in order of significance. Only the
 *     `col_index` field of each entry is consulted here. If `spec` is empty,
 *     all rows are returned as one single group in their existing order.
 *
 * @param as_view
 *     If true, the grouping columns are left untouched and are sorted
 *     through the rowindex of the first column; all grouping columns are
 *     expected (xassert) to share that rowindex. If false, `reify()` is
 *     called on every grouping column before sorting.
 */
RiGb DataTable::group(const std::vector<sort_spec>& spec, bool as_view) const {
  size_t nsortcols = spec.size();
  if (nsortcols == 0) {
    // There is no `spec[0]` to read (doing so would be undefined behavior).
    // Grouping by no columns puts every row into one group, in the
    // existing row order.
    size_t start = 0;
    return RiGb(RowIndex(start, nrows, 1),
                Groupby::single_group(nrows));
  }

  Column* col0 = columns[spec[0].col_index];
  if (as_view) {
    // Check that the sorted columns have consistent rowindices.
    for (size_t j = 1; j < nsortcols; ++j) {
      xassert(columns[spec[j].col_index]->rowindex() == col0->rowindex());
    }
  } else {
    for (size_t j = 0; j < nsortcols; ++j) {
      columns[spec[j].col_index]->reify();
    }
  }

  if (nrows <= 1) {
    // Nothing to sort: the result is a slice of `nrows` elements. With zero
    // rows the rowindex has no element 0, so it must not be read; the start
    // of an empty slice is irrelevant and 0 is used instead.
    size_t i = (nrows == 0) ? 0 : col0->rowindex()[0];
    return RiGb(RowIndex(i, nrows, 1),
                Groupby::single_group(nrows));
  }

  // Sort by the first column, then refine the ordering with each subsequent
  // column; group information is requested at every step.
  SortContext sc(nrows, col0->rowindex(), true);
  sc.start_sort(col0);
  for (size_t j = 1; j < nsortcols; ++j) {
    sc.continue_sort(columns[spec[j].col_index], true);
  }
  return sc.get_result_groups();
}


/**
* Sort the column, and return its ordering as a RowIndex object. This function
* will choose the most appropriate algorithm for sorting. The data in column
* `col` will not be modified.
*
* The function returns nullptr if there is a runtime error (for example an
* intermediate buffer cannot be allocated).
* Sort the Frame by columns at the specified positions `colindices`, and return
* their ordering as a RowIndex object. The data in the current Frame will not
* be modified.
*/
RowIndex DataTable::sortby(const std::vector<size_t>& colindices,
Groupby* out_grps) const
{
size_t nsortcols = colindices.size();
if (nrows > INT32_MAX) {
throw NotImplError() << "Cannot sort a datatable with " << nrows << " rows";
throw NotImplError() << "Cannot sort a Frame with " << nrows << " rows";
}
if (rowindex.isarr64() || rowindex.size() > INT32_MAX ||
(rowindex.max() > INT32_MAX && rowindex.max() != RowIndex::NA)) {
throw NotImplError() << "Cannot sort a datatable which is based on a "
"datatable with >2**31 rows";
throw NotImplError() << "Cannot sort a Frame which is a view on another "
"Frame with more than 2**31 rows";
}
// TODO: fix for the multi-rowindex case (#1188)
// A frame can be sorted by columns col1, ..., colN if and only if all these
Expand Down Expand Up @@ -1224,7 +1253,13 @@ RowIndex DataTable::sortby(const std::vector<size_t>& colindices,
sc.continue_sort(columns[colindices[j]],
(out_grps != nullptr) || (j < nsortcols - 1));
}
return sc.get_result(out_grps);
if (out_grps) {
auto res = sc.get_result_groups();
*out_grps = std::move(res.second);
return res.first;
} else {
return sc.get_result_rowindex();
}
}


Expand All @@ -1249,5 +1284,11 @@ RowIndex Column::sort(Groupby* out_grps) const {
}
SortContext sc(nrows, rowindex(), (out_grps != nullptr));
sc.start_sort(this);
return sc.get_result(out_grps);
if (out_grps) {
auto res = sc.get_result_groups();
*out_grps = std::move(res.second);
return res.first;
} else {
return sc.get_result_rowindex();
}
}
19 changes: 19 additions & 0 deletions tests/test_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,3 +261,22 @@ def test_groupby_multi_large(seed):
DT1 = DT0[:, sum(f.D), by(f.A, f.B, f.C)]
DT2 = dt.Frame(grouped)
assert same_iterables(DT1.to_list(), DT2.to_list())


def test_groupby_on_view():
    """Regression test for issue #1542: groupby applied to a view Frame."""
    source = dt.Frame(A=[1, 2, 3, 1, 2, 3],
                      B=[3, 6, 2, 4, 3, 1],
                      C=['b', 'd', 'b', 'b', 'd', 'b'])

    # Filtering the rows must produce a view that keeps only rows with A != 1.
    view = source[f.A != 1, :]
    assert view.internal.isview
    assert view.shape == (4, 3)
    expected_view = {'A': [2, 3, 2, 3],
                     'B': [6, 2, 3, 1],
                     'C': ['d', 'b', 'd', 'b']}
    assert view.to_dict() == expected_view

    # Grouping the view by C must only see the rows visible through the view.
    grouped = view[:, max(f.B), by(f.C)]
    assert grouped.shape == (2, 2)
    expected_grouped = {'C': ['b', 'd'],
                        'C0': [2, 6]}
    assert grouped.to_dict() == expected_grouped


0 comments on commit 0e20f4e

Please sign in to comment.