Skip to content

Commit

Permalink
Unit tests for Gram-Schmidt methods on CPU backend (#132)
Browse files Browse the repository at this point in the history
* GS unit tests for low-synch cpu version

* Add HIP tests for Gram-Schmidt methods.

---------

Co-authored-by: kswirydo <kasia.swirydowicz@pnnl.gov>
  • Loading branch information
pelesh and kswirydo authored Dec 21, 2023
1 parent afefdd9 commit ddeb557
Show file tree
Hide file tree
Showing 7 changed files with 70 additions and 31 deletions.
16 changes: 11 additions & 5 deletions resolve/GramSchmidt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -197,13 +197,14 @@ namespace ReSolve
vec_w_->setData(V->getVectorData(i + 1, memspace), memspace);
vec_rv_->setCurrentSize(i + 1);

vector_handler_->massDot2Vec(n, V, i, vec_v_, vec_rv_, memspace);
// vector_handler_->massDot2Vec(n, V, i + 1, vec_v_, vec_rv_, memspace);
vector_handler_->massDot2Vec(n, V, i + 1, vec_v_, vec_rv_, memspace);
vec_rv_->setDataUpdated(memspace);
vec_rv_->copyData(memspace, memory::HOST);

vec_rv_->deepCopyVectorData(&h_L_[idxmap(i, 0, num_vecs_ + 1)], 0, memory::HOST);
h_rv_ = vec_rv_->getVectorData(1, memory::HOST);

for(int j=0; j<=i; ++j) {
H[ idxmap(i, j, num_vecs_ + 1) ] = 0.0;
}
Expand All @@ -218,7 +219,7 @@ namespace ReSolve
} // for j
vec_Hcolumn_->setCurrentSize(i + 1);
vec_Hcolumn_->update(&H[ idxmap(i, 0, num_vecs_ + 1)], memory::HOST, memspace);
vector_handler_->massAxpy(n, vec_Hcolumn_, i, V, vec_w_, memspace);
vector_handler_->massAxpy(n, vec_Hcolumn_, i + 1, V, vec_w_, memspace);

// normalize (second synch)
t = vector_handler_->dot(vec_w_, vec_w_, memspace);
Expand All @@ -228,6 +229,11 @@ namespace ReSolve
if(fabs(t) > EPSILON) {
t = 1.0 / t;
vector_handler_->scal(&t, vec_w_, memspace);
for (int ii=0; ii<=i; ++ii)
{
vec_v_->setData(V->getVectorData(ii, memspace), memspace);
vec_w_->setData(V->getVectorData(i + 1, memspace), memspace);
}
} else {
assert(0 && "Iterative refinement failed, Krylov vector with ZERO norm\n");
return -1;
Expand All @@ -240,7 +246,7 @@ namespace ReSolve
vec_w_->setData(V->getVectorData(i + 1, memspace), memspace);
vec_rv_->setCurrentSize(i + 1);

vector_handler_->massDot2Vec(n, V, i, vec_v_, vec_rv_, memspace);
vector_handler_->massDot2Vec(n, V, i + 1, vec_v_, vec_rv_, memspace);
vec_rv_->setDataUpdated(memspace);
vec_rv_->copyData(memspace, memory::HOST);

Expand Down Expand Up @@ -290,7 +296,7 @@ namespace ReSolve
vec_Hcolumn_->setCurrentSize(i + 1);
vec_Hcolumn_->update(&H[ idxmap(i, 0, num_vecs_ + 1)], memory::HOST, memspace);

vector_handler_->massAxpy(n, vec_Hcolumn_, i, V, vec_w_, memspace);
vector_handler_->massAxpy(n, vec_Hcolumn_, i + 1, V, vec_w_, memspace);
// normalize (second synch)
t = vector_handler_->dot(vec_w_, vec_w_, memspace);
//set the last entry in Hessenberg matrix
Expand Down
4 changes: 2 additions & 2 deletions resolve/cuda/cudaKernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ namespace ReSolve {
const real_type* mvec,
real_type* result)
{
kernels::MassIPTwoVec<<<i + 1, 1024>>>(vec1, vec2, mvec, result, i + 1, n);
kernels::MassIPTwoVec<<<i, 1024>>>(vec1, vec2, mvec, result, i, n);
}

/**
Expand All @@ -421,7 +421,7 @@ namespace ReSolve {
*/
void mass_axpy(index_type n, index_type i, const real_type* x, real_type* y, const real_type* alpha)
{
kernels::massAxpy3<<<(n + 384 - 1) / 384, 384>>>(n, i + 1, x, y, alpha);
kernels::massAxpy3<<<(n + 384 - 1) / 384, 384>>>(n, i, x, y, alpha);
}

/**
Expand Down
4 changes: 2 additions & 2 deletions resolve/hip/hipKernels.hip
Original file line number Diff line number Diff line change
Expand Up @@ -524,15 +524,15 @@ namespace ReSolve {
real_type* result)
{
hipLaunchKernelGGL(kernels::MassIPTwoVec_kernel,
dim3(i + 1),
dim3(i),
dim3(1024),
0,
0,
vec1,
vec2,
mvec,
result,
i + 1,
i,
n);
}

Expand Down
29 changes: 20 additions & 9 deletions resolve/vector/VectorHandlerCpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,21 +220,32 @@ namespace ReSolve {
*/
void VectorHandlerCpu::massDot2Vec(index_type size,
vector::Vector* V,
index_type k,
index_type q,
vector::Vector* x,
vector::Vector* res)
{
real_type* res_data = res->getData(memory::HOST);
real_type* x_data = x->getData(memory::HOST);
real_type* V_data = V->getData(memory::HOST);
index_type i, j;
real_type* x_data = x->getData(memory::HOST);
real_type* V_data = V->getData(memory::HOST);

for (i = 0; i < k; ++i) {
real_type c0 = 0.0;
real_type cq = 0.0;

for (index_type i = 0; i < q; ++i) {
res_data[i] = 0.0;
res_data[i + k] = 0.0;
for (j = 0; j < size; ++j) {
res_data[i] += V_data[i * size + j] * x_data[j];
res_data[i + k] += V_data[i * size + j] * x_data[j + size];
res_data[i + q] = 0.0;

// Make sure we don't accumulate round-off errors
for (index_type j = 0; j < size; ++j) {
real_type y0 = (V_data[i * size + j] * x_data[j]) - c0;
real_type yq = (V_data[i * size + j] * x_data[j + size]) - cq;
real_type t0 = res_data[i] + y0;
real_type tq = res_data[i + q] + yq;
c0 = (t0 - res_data[i] ) - y0;
cq = (tq - res_data[i + q]) - yq;

res_data[i] = t0;
res_data[i + q] = tq;
}
}
}
Expand Down
8 changes: 4 additions & 4 deletions resolve/vector/VectorHandlerCuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,12 +206,12 @@ namespace ReSolve {
CUBLAS_OP_N,
size, // m
1, // n
k + 1, // k
k, // k
&MINUSONE, // alpha
x->getData(memory::DEVICE), // A
size, // lda
alpha->getData(memory::DEVICE), // B
k + 1, // ldb
k, // ldb
&ONE,
y->getData(memory::DEVICE), // c
size); // ldc
Expand Down Expand Up @@ -243,7 +243,7 @@ namespace ReSolve {
cublasDgemm(handle_cublas,
CUBLAS_OP_T,
CUBLAS_OP_N,
k + 1, //m
k, //m
2, //n
size, //k
&ONE, //alpha
Expand All @@ -253,7 +253,7 @@ namespace ReSolve {
size, //ldb
&ZERO,
res->getData(memory::DEVICE), //c
k + 1); //ldc
k); //ldc
}
}

Expand Down
15 changes: 6 additions & 9 deletions tests/unit/vector/GramSchmidtTests.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,10 @@ namespace ReSolve {
TestOutcome GramSchmidtConstructor()
{
TestStatus status;
// status.skipTest();

// GramSchmidt gs1;
// status *= (gs1.getVariant() == GramSchmidt::mgs);
// status *= (gs1.getL() == nullptr);
// status *= !gs1.isSetupComplete();

VectorHandler vh;
GramSchmidt gs2(&vh, GramSchmidt::mgs_pm);
status *= (gs2.getVariant() == GramSchmidt::mgs_pm);
// status *= (gs1.getL() == nullptr);
// status *= !gs1.isSetupComplete();

return status.report(__func__);
}
Expand Down Expand Up @@ -140,6 +132,12 @@ namespace ReSolve {
LinAlgWorkspaceCUDA* workspace = new LinAlgWorkspaceCUDA();
workspace->initializeHandles();
return new VectorHandler(workspace);
#endif
#ifdef RESOLVE_USE_HIP
} else if (memspace_ == "hip") {
LinAlgWorkspaceHIP* workspace = new LinAlgWorkspaceHIP();
workspace->initializeHandles();
return new VectorHandler(workspace);
#endif
} else {
std::cout << "ReSolve not built with support for memory space " << memspace_ << "\n";
Expand Down Expand Up @@ -167,7 +165,6 @@ namespace ReSolve {
a->update(x->getVectorData(i, ms), ms, memory::HOST);
b->update(x->getVectorData(j, ms), ms, memory::HOST);
ip = handler->dot(a, b, memory::HOST);

if ( (i != j) && (abs(ip) > 1e-14)) {
status = false;
std::cout << "Vectors " << i << " and " << j << " are not orthogonal!"
Expand Down
25 changes: 25 additions & 0 deletions tests/unit/vector/runGramSchmidtTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,30 @@ int main(int, char**)
}
#endif

#ifdef RESOLVE_USE_HIP
{
std::cout << "Running tests with HIP backend:\n";
ReSolve::tests::GramSchmidtTests test("hip");

result += test.GramSchmidtConstructor();
result += test.orthogonalize(5000, ReSolve::GramSchmidt::mgs);
result += test.orthogonalize(5000, ReSolve::GramSchmidt::cgs2);
result += test.orthogonalize(5000, ReSolve::GramSchmidt::mgs_two_synch);
result += test.orthogonalize(5000, ReSolve::GramSchmidt::mgs_pm);
std::cout << "\n";
}
#endif

{
std::cout << "Running tests on the CPU:\n";
ReSolve::tests::GramSchmidtTests test("cpu");

result += test.GramSchmidtConstructor();
result += test.orthogonalize(5000, ReSolve::GramSchmidt::mgs);
result += test.orthogonalize(5000, ReSolve::GramSchmidt::cgs2);
result += test.orthogonalize(5000, ReSolve::GramSchmidt::mgs_two_synch);
result += test.orthogonalize(5000, ReSolve::GramSchmidt::mgs_pm);
std::cout << "\n";
}
return result.summary();
}

0 comments on commit ddeb557

Please sign in to comment.