Skip to content

Commit

Permalink
sycl - consistency and correctness fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremylt committed Sep 11, 2024
1 parent dd9ae86 commit a1e876b
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 42 deletions.
91 changes: 50 additions & 41 deletions backends/sycl-ref/ceed-sycl-ref-operator.sycl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ static int CeedOperatorDestroy_Sycl(CeedOperator op) {
}
CeedCallBackend(CeedFree(&impl->q_vecs_out));

// QFunction assembly dataf
// QFunction assembly data
for (CeedInt i = 0; i < impl->num_active_in; i++) {
CeedCallBackend(CeedVectorDestroy(&impl->qf_active_in[i]));
}
Expand Down Expand Up @@ -132,15 +132,15 @@ static int CeedOperatorSetupFields_Sycl(CeedQFunction qf, CeedOperator op, bool
for (CeedInt i = 0; i < num_fields; i++) {
CeedEvalMode eval_mode;
CeedVector vec;
CeedElemRestriction rstr;
CeedElemRestriction elem_rstr;
CeedBasis basis;

CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_fields[i], &eval_mode));

is_strided = false;
skip_restriction = false;
if (eval_mode != CEED_EVAL_WEIGHT) {
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &rstr));
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_fields[i], &elem_rstr));

// Check whether this field can skip the element restriction:
// must be passive input, with eval_mode NONE, and have a strided restriction with CEED_STRIDES_BACKEND.
Expand All @@ -153,10 +153,10 @@ static int CeedOperatorSetupFields_Sycl(CeedQFunction qf, CeedOperator op, bool
// Check eval_mode
if (eval_mode == CEED_EVAL_NONE) {
// Check for is_strided restriction
CeedCallBackend(CeedElemRestrictionIsStrided(rstr, &is_strided));
CeedCallBackend(CeedElemRestrictionIsStrided(elem_rstr, &is_strided));
if (is_strided) {
// Check if vector is already in preferred backend ordering
CeedCallBackend(CeedElemRestrictionHasBackendStrides(rstr, &skip_restriction));
CeedCallBackend(CeedElemRestrictionHasBackendStrides(elem_rstr, &skip_restriction));
}
}
}
Expand All @@ -166,9 +166,9 @@ static int CeedOperatorSetupFields_Sycl(CeedQFunction qf, CeedOperator op, bool
// We do not need an E-Vector, but will use the input field vector's data directly in the operator application
e_vecs[i + start_e] = NULL;
} else {
CeedCallBackend(CeedElemRestrictionCreateVector(rstr, NULL, &e_vecs[i + start_e]));
CeedCallBackend(CeedElemRestrictionCreateVector(elem_rstr, NULL, &e_vecs[i + start_e]));
}
CeedCallBackend(CeedElemRestrictionDestroy(&rstr));
CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
}

switch (eval_mode) {
Expand Down Expand Up @@ -276,11 +276,11 @@ static inline int CeedOperatorSetupInputs_Sycl(CeedInt num_input_fields, CeedQFu
// No restriction for this field; read data directly from vec.
CeedCallBackend(CeedVectorGetArrayRead(vec, CEED_MEM_DEVICE, (const CeedScalar **)&e_data[i]));
} else {
CeedElemRestriction rstr;
CeedElemRestriction elem_rstr;

CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &rstr));
CeedCallBackend(CeedElemRestrictionApply(rstr, CEED_NOTRANSPOSE, vec, impl->e_vecs[i], request));
CeedCallBackend(CeedElemRestrictionDestroy(&rstr));
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_input_fields[i], &elem_rstr));
CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_NOTRANSPOSE, vec, impl->e_vecs[i], request));
CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
CeedCallBackend(CeedVectorGetArrayRead(impl->e_vecs[i], CEED_MEM_DEVICE, (const CeedScalar **)&e_data[i]));
}
}
Expand All @@ -297,7 +297,6 @@ static inline int CeedOperatorInputBasis_Sycl(CeedInt num_elem, CeedQFunctionFie
CeedOperator_Sycl *impl) {
for (CeedInt i = 0; i < num_input_fields; i++) {
CeedEvalMode eval_mode;
CeedBasis basis;

// Skip active input
if (skip_active) {
Expand All @@ -316,11 +315,14 @@ static inline int CeedOperatorInputBasis_Sycl(CeedInt num_elem, CeedQFunctionFie
CeedCallBackend(CeedVectorSetArray(impl->q_vecs_in[i], CEED_MEM_DEVICE, CEED_USE_POINTER, e_data[i]));
break;
case CEED_EVAL_INTERP:
case CEED_EVAL_GRAD:
case CEED_EVAL_GRAD: {
CeedBasis basis;

CeedCallBackend(CeedOperatorFieldGetBasis(op_input_fields[i], &basis));
CeedCallBackend(CeedBasisApply(basis, num_elem, CEED_NOTRANSPOSE, eval_mode, impl->e_vecs[i], impl->q_vecs_in[i]));
CeedCallBackend(CeedBasisDestroy(&basis));
break;
}
case CEED_EVAL_WEIGHT:
break; // No action
case CEED_EVAL_DIV:
Expand Down Expand Up @@ -405,32 +407,35 @@ static int CeedOperatorApplyAdd_Sycl(CeedOperator op, CeedVector in_vec, CeedVec

// Output basis apply if needed
for (CeedInt i = 0; i < num_output_fields; i++) {
CeedElemRestriction rstr;
CeedBasis basis;
CeedElemRestriction elem_rstr;

// Get elem_size, eval_mode, size
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &rstr));
CeedCallBackend(CeedElemRestrictionGetElementSize(rstr, &elem_size));
CeedCallBackend(CeedElemRestrictionDestroy(&rstr));
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr));
CeedCallBackend(CeedElemRestrictionGetElementSize(elem_rstr, &elem_size));
CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
CeedCallBackend(CeedQFunctionFieldGetSize(qf_output_fields[i], &size));
// Basis action
switch (eval_mode) {
case CEED_EVAL_NONE:
break;
case CEED_EVAL_INTERP:
case CEED_EVAL_GRAD:
case CEED_EVAL_GRAD: {
CeedBasis basis;

CeedCallBackend(CeedOperatorFieldGetBasis(op_output_fields[i], &basis));
CeedCallBackend(CeedBasisApply(basis, num_elem, CEED_TRANSPOSE, eval_mode, impl->q_vecs_out[i], impl->e_vecs[i + impl->num_e_in]));
CeedCallBackend(CeedBasisDestroy(&basis));
break;
}
// LCOV_EXCL_START
case CEED_EVAL_WEIGHT:
case CEED_EVAL_WEIGHT: {
Ceed ceed;

CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
return CeedError(ceed, CEED_ERROR_BACKEND, "CEED_EVAL_WEIGHT cannot be an output evaluation mode");
break; // Should not occur
}
case CEED_EVAL_DIV:
case CEED_EVAL_CURL: {
Ceed ceed;
Expand All @@ -448,7 +453,7 @@ static int CeedOperatorApplyAdd_Sycl(CeedOperator op, CeedVector in_vec, CeedVec
bool is_active;
CeedEvalMode eval_mode;
CeedVector vec;
CeedElemRestriction rstr;
CeedElemRestriction elem_rstr;

// Restore evec
CeedCallBackend(CeedQFunctionFieldGetEvalMode(qf_output_fields[i], &eval_mode));
Expand All @@ -458,11 +463,11 @@ static int CeedOperatorApplyAdd_Sycl(CeedOperator op, CeedVector in_vec, CeedVec
// Restrict
CeedCallBackend(CeedOperatorFieldGetVector(op_output_fields[i], &vec));
is_active = vec == CEED_VECTOR_ACTIVE;
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &rstr));
CeedCallBackend(CeedOperatorFieldGetElemRestriction(op_output_fields[i], &elem_rstr));
if (is_active) vec = out_vec;
CeedCallBackend(CeedElemRestrictionApply(rstr, CEED_TRANSPOSE, impl->e_vecs[i + impl->num_inputs], vec, request));
CeedCallBackend(CeedElemRestrictionApply(elem_rstr, CEED_TRANSPOSE, impl->e_vecs[i + impl->num_e_in], vec, request));
if (!is_active) CeedCallBackend(CeedVectorDestroy(&vec));
CeedCallBackend(CeedElemRestrictionDestroy(&rstr));
CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
}

// Restore input arrays
Expand All @@ -473,8 +478,8 @@ static int CeedOperatorApplyAdd_Sycl(CeedOperator op, CeedVector in_vec, CeedVec
//------------------------------------------------------------------------------
// Core code for assembling linear QFunction
//------------------------------------------------------------------------------
static inline int CeedOperatorLinearAssembleQFunctionCore_Sycl(CeedOperator op, bool build_objects, CeedVector *assembled, CeedElemRestriction *rstr,
CeedRequest *request) {
static inline int CeedOperatorLinearAssembleQFunctionCore_Sycl(CeedOperator op, bool build_objects, CeedVector *assembled,
CeedElemRestriction *elem_rstr, CeedRequest *request) {
Ceed ceed, ceed_parent;
CeedSize q_size;
CeedInt num_active_in, num_active_out, Q, num_elem, num_input_fields, num_output_fields, size;
Expand Down Expand Up @@ -556,7 +561,7 @@ static inline int CeedOperatorLinearAssembleQFunctionCore_Sycl(CeedOperator op,

// Create output restriction
CeedCallBackend(CeedElemRestrictionCreateStrided(ceed_parent, num_elem, Q, num_active_in * num_active_out,
num_active_in * num_active_out * num_elem * Q, strides, rstr));
num_active_in * num_active_out * num_elem * Q, strides, elem_rstr));
// Create assembled vector
CeedCallBackend(CeedVectorCreate(ceed_parent, l_size, assembled));
}
Expand Down Expand Up @@ -613,15 +618,16 @@ static inline int CeedOperatorLinearAssembleQFunctionCore_Sycl(CeedOperator op,
//------------------------------------------------------------------------------
// Assemble Linear QFunction
//------------------------------------------------------------------------------
static int CeedOperatorLinearAssembleQFunction_Sycl(CeedOperator op, CeedVector *assembled, CeedElemRestriction *rstr, CeedRequest *request) {
return CeedOperatorLinearAssembleQFunctionCore_Sycl(op, true, assembled, rstr, request);
static int CeedOperatorLinearAssembleQFunction_Sycl(CeedOperator op, CeedVector *assembled, CeedElemRestriction *elem_rstr, CeedRequest *request) {
return CeedOperatorLinearAssembleQFunctionCore_Sycl(op, true, assembled, elem_rstr, request);
}

//------------------------------------------------------------------------------
// Update Assembled Linear QFunction
//------------------------------------------------------------------------------
static int CeedOperatorLinearAssembleQFunctionUpdate_Sycl(CeedOperator op, CeedVector assembled, CeedElemRestriction rstr, CeedRequest *request) {
return CeedOperatorLinearAssembleQFunctionCore_Sycl(op, false, &assembled, &rstr, request);
static int CeedOperatorLinearAssembleQFunctionUpdate_Sycl(CeedOperator op, CeedVector assembled, CeedElemRestriction elem_rstr,
CeedRequest *request) {
return CeedOperatorLinearAssembleQFunctionCore_Sycl(op, false, &assembled, &elem_rstr, request);
}

//------------------------------------------------------------------------------
Expand Down Expand Up @@ -893,22 +899,25 @@ static int CeedOperatorLinearDiagonal_Sycl(sycl::queue &sycl_queue, const bool i
// Assemble diagonal common code
//------------------------------------------------------------------------------
static inline int CeedOperatorAssembleDiagonalCore_Sycl(CeedOperator op, CeedVector assembled, CeedRequest *request, const bool is_point_block) {
Ceed ceed;
Ceed_Sycl *sycl_data;
CeedInt num_elem;
CeedScalar *elem_diag_array;
const CeedScalar *assembled_qf_array;
CeedVector assembled_qf = NULL;
CeedElemRestriction rstr = NULL;
CeedOperator_Sycl *impl;
Ceed ceed;
Ceed_Sycl *sycl_data;
CeedInt num_elem;
CeedScalar *elem_diag_array;
const CeedScalar *assembled_qf_array;
CeedVector assembled_qf = NULL;
CeedOperator_Sycl *impl;

CeedCallBackend(CeedOperatorGetCeed(op, &ceed));
CeedCallBackend(CeedOperatorGetData(op, &impl));
CeedCallBackend(CeedGetData(ceed, &sycl_data));

// Assemble QFunction
CeedCallBackend(CeedOperatorLinearAssembleQFunctionBuildOrUpdate(op, &assembled_qf, &rstr, request));
CeedCallBackend(CeedElemRestrictionDestroy(&rstr));
{
CeedElemRestriction elem_rstr = NULL;

CeedCallBackend(CeedOperatorLinearAssembleQFunctionBuildOrUpdate(op, &assembled_qf, &elem_rstr, request));
CeedCallBackend(CeedElemRestrictionDestroy(&elem_rstr));
}

// Setup
if (!impl->diag) {
Expand Down
1 change: 0 additions & 1 deletion backends/sycl-ref/ceed-sycl-ref.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@ typedef struct {
CeedVector *q_vecs_out; // Output Q-vectors needed to apply operator
CeedInt num_e_in;
CeedInt num_e_out;
CeedInt num_inputs, num_outputs;
CeedInt num_active_in, num_active_out;
CeedVector *qf_active_in;
CeedOperatorDiag_Sycl *diag;
Expand Down

0 comments on commit a1e876b

Please sign in to comment.