Skip to content

Commit

Permalink
Close tdbMatrix and tdbMatrixWithIds Array's when we have nothing lef…
Browse files Browse the repository at this point in the history
…t to read (#466)
  • Loading branch information
jparismorgan authored Jul 26, 2024
1 parent 5db39ff commit 7adba24
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 4 deletions.
12 changes: 10 additions & 2 deletions src/include/detail/linalg/tdb_matrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ class tdbBlockedMatrix : public MatrixBase {
// size_t pending_row_offset{0};
// size_t pending_col_offset{0};

size_t get_elements_to_load() const {
return std::min(load_blocksize_, last_col_ - last_resident_col_);
}

public:
tdbBlockedMatrix(tdbBlockedMatrix&& rhs) = default;

Expand Down Expand Up @@ -324,11 +328,11 @@ class tdbBlockedMatrix : public MatrixBase {
}

size_t dimension = last_row_ - first_row_;
auto elements_to_load =
std::min(load_blocksize_, last_col_ - last_resident_col_);
auto elements_to_load = get_elements_to_load();

// Return if we're at the end
if (elements_to_load == 0 || dimension == 0) {
array_->close();
return false;
}

Expand Down Expand Up @@ -363,6 +367,10 @@ class tdbBlockedMatrix : public MatrixBase {
throw std::runtime_error("Query status is not complete");
}

if (get_elements_to_load() == 0) {
array_->close();
}

num_loads_++;
return true;
}
Expand Down
5 changes: 5 additions & 0 deletions src/include/detail/linalg/tdb_matrix_with_ids.h
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ class tdbBlockedMatrixWithIds
bool load() {
scoped_timer _{tdb_func__ + " " + this->ids_uri_};
if (!Base::load()) {
ids_array_->close();
return false;
}

Expand Down Expand Up @@ -214,6 +215,10 @@ class tdbBlockedMatrixWithIds
throw std::runtime_error("Query status for IDs is not complete");
}

if (this->get_elements_to_load() == 0) {
ids_array_->close();
}

return true;
}
}; // tdbBlockedMatrixWithIds
Expand Down
5 changes: 4 additions & 1 deletion src/include/test/unit_tdb_matrix.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,10 @@ TEMPLATE_TEST_CASE("constructors", "[tdb_matrix]", float, uint8_t) {
write_matrix(ctx, X, tmp_matrix_uri);

auto Y = tdbColMajorMatrix<TestType>(ctx, tmp_matrix_uri);
Y.load();
CHECK(Y.load() == true);
for (int i = 0; i < 5; ++i) {
CHECK(Y.load() == false);
}

auto Z = tdbColMajorMatrix<TestType>(std::move(Y));

Expand Down
5 changes: 4 additions & 1 deletion src/include/test/unit_tdb_matrix_with_ids.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,10 @@ TEMPLATE_TEST_CASE(

auto Y = tdbColMajorMatrixWithIds<TestType, TestType>(
ctx, tmp_matrix_uri, tmp_ids_uri);
Y.load();
CHECK(Y.load() == true);
for (int i = 0; i < 5; ++i) {
CHECK(Y.load() == false);
}
CHECK(num_vectors(Y) == num_vectors(X));
CHECK(dimensions(Y) == dimensions(X));
CHECK(std::equal(
Expand Down

0 comments on commit 7adba24

Please sign in to comment.