Skip to content

Commit

Permalink
gpu - reuse evecs for AtPoints where able
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremylt committed Aug 26, 2024
1 parent 3ceddb9 commit 95c081e
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 10 deletions.
56 changes: 51 additions & 5 deletions backends/cuda-ref/ceed-cuda-ref-operator.c
Original file line number Diff line number Diff line change
Expand Up @@ -574,14 +574,52 @@ static int CeedOperatorSetupAtPoints_Cuda(CeedOperator op) {
impl->num_inputs = num_input_fields;
impl->num_outputs = num_output_fields;

// Set up infield and outfield e_vecs and q_vecs
// Set up infield and outfield e-vecs and q-vecs
// Infields
CeedCallBackend(CeedOperatorSetupFields_Cuda(qf, op, true, true, impl->skip_rstr_in, NULL, impl->e_vecs, impl->q_vecs_in, 0, num_input_fields,
max_num_points, num_elem));
// Outfields
CeedCallBackend(CeedOperatorSetupFields_Cuda(qf, op, false, true, impl->skip_rstr_out, impl->apply_add_basis_out, impl->e_vecs, impl->q_vecs_out,
num_input_fields, num_output_fields, max_num_points, num_elem));

// Reuse active e-vecs where able
{
CeedInt num_used = 0;
CeedElemRestriction *rstr_used = NULL;

for (CeedInt i = 0; i < num_input_fields; i++) {
bool is_used = false;
CeedVector vec_i;
CeedElemRestriction rstr_i;

CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec_i));
if (vec_i != CEED_VECTOR_ACTIVE) continue;
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &rstr_i));
for (CeedInt j = 0; j < num_used; j++) {
if (rstr_i == rstr_used[i]) is_used = true;
}
if (is_used) continue;
num_used++;
if (num_used == 1) CeedCallBackend(CeedCalloc(num_used, &rstr_used));
else CeedCallBackend(CeedRealloc(num_used, &rstr_used));
rstr_used[num_used - 1] = rstr_i;
for (CeedInt j = num_output_fields - 1; j >= 0; j--) {
CeedEvalMode eval_mode;
CeedVector vec_j;
CeedElemRestriction rstr_j;

CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[j], &vec_j));
if (vec_j != CEED_VECTOR_ACTIVE) continue;
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[j], &eval_mode));
if (eval_mode == CEED_EVAL_NONE) continue;
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[j], &rstr_j));
if (rstr_i == rstr_j) {
CeedCallBackend(CeedVectorReferenceCopy(impl->e_vecs[i], &impl->e_vecs[j + impl->num_inputs]));
}
}
}
CeedCallBackend(CeedFree(&rstr_used));
}
CeedCallBackend(CeedOperatorSetSetupDone(op));
return CEED_ERROR_SUCCESS;
}
Expand Down Expand Up @@ -684,6 +722,9 @@ static int CeedOperatorApplyAddAtPoints_Cuda(CeedOperator op, CeedVector in_vec,
// Q function
CeedCallBackend(CeedQFunctionApply(qf, num_elem * max_num_points, impl->q_vecs_in, impl->q_vecs_out));

// Restore input arrays
CeedCallBackend(CeedOperatorRestoreInputs_Cuda(num_input_fields, qf_input_fields, op_input_fields, false, e_data, impl));

// Output basis apply if needed
for (CeedInt i = 0; i < num_output_fields; i++) {
CeedEvalMode eval_mode;
Expand Down Expand Up @@ -741,9 +782,6 @@ static int CeedOperatorApplyAddAtPoints_Cuda(CeedOperator op, CeedVector in_vec,

CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_TRANSPOSE, impl->e_vecs[i + impl->num_inputs], vec, request));
}

// Restore input arrays
CeedCallBackend(CeedOperatorRestoreInputs_Cuda(num_input_fields, qf_input_fields, op_input_fields, false, e_data, impl));
return CEED_ERROR_SUCCESS;
}

Expand Down Expand Up @@ -868,7 +906,7 @@ static inline int CeedOperatorLinearAssembleQFunctionCore_Cuda(CeedOperator op,
CeedCallBackend(CeedQFunctionApply(qf, Q * num_elem, impl->q_vecs_in, impl->q_vecs_out));
}

// Un-set output q_vecs to prevent accidental overwrite of Assembled
// Un-set output q-vecs to prevent accidental overwrite of Assembled
for (CeedInt out = 0; out < num_output_fields; out++) {
CeedVector vec;

Expand Down Expand Up @@ -1595,6 +1633,14 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Cuda(CeedOperator op, C
max_num_points = impl->max_num_points;
for (CeedInt i = 0; i < num_elem; i++) num_points[i] = max_num_points;

// Create separate output e-vecs
for (CeedInt i = 0; i < impl->num_outputs; i++) {
CeedCallBackend(CeedVectorDestroy(&impl->q_vecs_out[i]));
CeedCallBackend(CeedVectorDestroy(&impl->e_vecs[impl->num_inputs + i]));
}
CeedCallBackend(CeedOperatorSetupFields_Cuda(qf, op, false, true, impl->skip_rstr_out, impl->apply_add_basis_out, impl->e_vecs, impl->q_vecs_out,
num_input_fields, num_output_fields, max_num_points, num_elem));

// Input Evecs and Restriction
CeedCallBackend(CeedOperatorSetupInputs_Cuda(num_input_fields, qf_input_fields, op_input_fields, NULL, true, e_data, impl, request));

Expand Down
56 changes: 51 additions & 5 deletions backends/hip-ref/ceed-hip-ref-operator.c
Original file line number Diff line number Diff line change
Expand Up @@ -573,14 +573,52 @@ static int CeedOperatorSetupAtPoints_Hip(CeedOperator op) {
impl->num_inputs = num_input_fields;
impl->num_outputs = num_output_fields;

// Set up infield and outfield e_vecs and q_vecs
// Set up infield and outfield e-vecs and q-vecs
// Infields
CeedCallBackend(CeedOperatorSetupFields_Hip(qf, op, true, true, impl->skip_rstr_in, NULL, impl->e_vecs, impl->q_vecs_in, 0, num_input_fields,
max_num_points, num_elem));
// Outfields
CeedCallBackend(CeedOperatorSetupFields_Hip(qf, op, false, true, impl->skip_rstr_out, impl->apply_add_basis_out, impl->e_vecs, impl->q_vecs_out,
num_input_fields, num_output_fields, max_num_points, num_elem));

// Reuse active e-vecs where able
{
CeedInt num_used = 0;
CeedElemRestriction *rstr_used = NULL;

for (CeedInt i = 0; i < num_input_fields; i++) {
bool is_used = false;
CeedVector vec_i;
CeedElemRestriction rstr_i;

CeedCallBackend(CeedOperatorFieldGetVector(op_input_fields[i], &vec_i));
if (vec_i != CEED_VECTOR_ACTIVE) continue;
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &rstr_i));
for (CeedInt j = 0; j < num_used; j++) {
if (rstr_i == rstr_used[i]) is_used = true;
}
if (is_used) continue;
num_used++;
if (num_used == 1) CeedCallBackend(CeedCalloc(num_used, &rstr_used));
else CeedCallBackend(CeedRealloc(num_used, &rstr_used));
rstr_used[num_used - 1] = rstr_i;
for (CeedInt j = num_output_fields - 1; j >= 0; j--) {
CeedEvalMode eval_mode;
CeedVector vec_j;
CeedElemRestriction rstr_j;

CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[j], &vec_j));
if (vec_j != CEED_VECTOR_ACTIVE) continue;
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[j], &eval_mode));
if (eval_mode == CEED_EVAL_NONE) continue;
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[j], &rstr_j));
if (rstr_i == rstr_j) {
CeedCallBackend(CeedVectorReferenceCopy(impl->e_vecs[i], &impl->e_vecs[j + impl->num_inputs]));
}
}
}
CeedCallBackend(CeedFree(&rstr_used));
}
CeedCallBackend(CeedOperatorSetSetupDone(op));
return CEED_ERROR_SUCCESS;
}
Expand Down Expand Up @@ -683,6 +721,9 @@ static int CeedOperatorApplyAddAtPoints_Hip(CeedOperator op, CeedVector in_vec,
// Q function
CeedCallBackend(CeedQFunctionApply(qf, num_elem * max_num_points, impl->q_vecs_in, impl->q_vecs_out));

// Restore input arrays
CeedCallBackend(CeedOperatorRestoreInputs_Hip(num_input_fields, qf_input_fields, op_input_fields, false, e_data, impl));

// Output basis apply if needed
for (CeedInt i = 0; i < num_output_fields; i++) {
CeedEvalMode eval_mode;
Expand Down Expand Up @@ -740,9 +781,6 @@ static int CeedOperatorApplyAddAtPoints_Hip(CeedOperator op, CeedVector in_vec,

CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_TRANSPOSE, impl->e_vecs[i + impl->num_inputs], vec, request));
}

// Restore input arrays
CeedCallBackend(CeedOperatorRestoreInputs_Hip(num_input_fields, qf_input_fields, op_input_fields, false, e_data, impl));
return CEED_ERROR_SUCCESS;
}

Expand Down Expand Up @@ -867,7 +905,7 @@ static inline int CeedOperatorLinearAssembleQFunctionCore_Hip(CeedOperator op, b
CeedCallBackend(CeedQFunctionApply(qf, Q * num_elem, impl->q_vecs_in, impl->q_vecs_out));
}

// Un-set output q_vecs to prevent accidental overwrite of Assembled
// Un-set output q-vecs to prevent accidental overwrite of Assembled
for (CeedInt out = 0; out < num_output_fields; out++) {
CeedVector vec;

Expand Down Expand Up @@ -1592,6 +1630,14 @@ static int CeedOperatorLinearAssembleAddDiagonalAtPoints_Hip(CeedOperator op, Ce
max_num_points = impl->max_num_points;
for (CeedInt i = 0; i < num_elem; i++) num_points[i] = max_num_points;

// Create separate output e-vecs
for (CeedInt i = 0; i < impl->num_outputs; i++) {
CeedCallBackend(CeedVectorDestroy(&impl->q_vecs_out[i]));
CeedCallBackend(CeedVectorDestroy(&impl->e_vecs[impl->num_inputs + i]));
}
CeedCallBackend(CeedOperatorSetupFields_Hip(qf, op, false, true, impl->skip_rstr_out, impl->apply_add_basis_out, impl->e_vecs, impl->q_vecs_out,
num_input_fields, num_output_fields, max_num_points, num_elem));

// Input Evecs and Restriction
CeedCallBackend(CeedOperatorSetupInputs_Hip(num_input_fields, qf_input_fields, op_input_fields, NULL, true, e_data, impl, request));

Expand Down

0 comments on commit 95c081e

Please sign in to comment.