Skip to content

Commit

Permalink
[mlir][openacc][NFC] Cleanup hasOnly functions for device_type support (
Browse files Browse the repository at this point in the history
#78800)

Just a cleanup for all the `has.*Only()` function to avoid code
duplication
  • Loading branch information
clementval authored Jan 22, 2024
1 parent b5df6a9 commit ee6199c
Showing 1 changed file with 49 additions and 101 deletions.
150 changes: 49 additions & 101 deletions mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,41 @@ void OpenACCDialect::initialize() {
*getContext());
}

//===----------------------------------------------------------------------===//
// device_type support helpers
//===----------------------------------------------------------------------===//

static bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
if (arrayAttr && *arrayAttr && arrayAttr->size() > 0)
return true;
return false;
}

static bool hasDeviceType(std::optional<mlir::ArrayAttr> arrayAttr,
mlir::acc::DeviceType deviceType) {
if (!hasDeviceTypeValues(arrayAttr))
return false;

for (auto attr : *arrayAttr) {
auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
if (deviceTypeAttr.getValue() == deviceType)
return true;
}

return false;
}

static void printDeviceTypes(mlir::OpAsmPrinter &p,
std::optional<mlir::ArrayAttr> deviceTypes) {
if (!hasDeviceTypeValues(deviceTypes))
return;

p << "[";
llvm::interleaveComma(*deviceTypes, p,
[&](mlir::Attribute attr) { p << attr; });
p << "]";
}

//===----------------------------------------------------------------------===//
// DataBoundsOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -722,11 +757,7 @@ bool acc::ParallelOp::hasAsyncOnly() {
}

bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
if (auto arrayAttr = getAsyncOnly()) {
if (findSegment(*arrayAttr, deviceType))
return true;
}
return false;
return hasDeviceType(getAsyncOnly(), deviceType);
}

mlir::Value acc::ParallelOp::getAsyncValue() {
Expand Down Expand Up @@ -789,11 +820,7 @@ bool acc::ParallelOp::hasWaitOnly() {
}

bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
if (auto arrayAttr = getWaitOnly()) {
if (findSegment(*arrayAttr, deviceType))
return true;
}
return false;
return hasDeviceType(getWaitOnly(), deviceType);
}

mlir::Operation::operand_range ParallelOp::getWaitValues() {
Expand Down Expand Up @@ -1033,23 +1060,6 @@ static ParseResult parseDeviceTypeOperandsWithKeywordOnly(
return success();
}

static bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
if (arrayAttr && *arrayAttr && arrayAttr->size() > 0)
return true;
return false;
}

static void printDeviceTypes(mlir::OpAsmPrinter &p,
std::optional<mlir::ArrayAttr> deviceTypes) {
if (!hasDeviceTypeValues(deviceTypes))
return;

p << "[";
llvm::interleaveComma(*deviceTypes, p,
[&](mlir::Attribute attr) { p << attr; });
p << "]";
}

static void printDeviceTypeOperandsWithKeywordOnly(
mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands,
mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
Expand Down Expand Up @@ -1093,11 +1103,7 @@ bool acc::SerialOp::hasAsyncOnly() {
}

bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
if (auto arrayAttr = getAsyncOnly()) {
if (findSegment(*arrayAttr, deviceType))
return true;
}
return false;
return hasDeviceType(getAsyncOnly(), deviceType);
}

mlir::Value acc::SerialOp::getAsyncValue() {
Expand All @@ -1114,11 +1120,7 @@ bool acc::SerialOp::hasWaitOnly() {
}

bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
if (auto arrayAttr = getWaitOnly()) {
if (findSegment(*arrayAttr, deviceType))
return true;
}
return false;
return hasDeviceType(getWaitOnly(), deviceType);
}

mlir::Operation::operand_range SerialOp::getWaitValues() {
Expand Down Expand Up @@ -1177,11 +1179,7 @@ bool acc::KernelsOp::hasAsyncOnly() {
}

bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
if (auto arrayAttr = getAsyncOnly()) {
if (findSegment(*arrayAttr, deviceType))
return true;
}
return false;
return hasDeviceType(getAsyncOnly(), deviceType);
}

mlir::Value acc::KernelsOp::getAsyncValue() {
Expand Down Expand Up @@ -1228,11 +1226,7 @@ bool acc::KernelsOp::hasWaitOnly() {
}

bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
if (auto arrayAttr = getWaitOnly()) {
if (findSegment(*arrayAttr, deviceType))
return true;
}
return false;
return hasDeviceType(getWaitOnly(), deviceType);
}

mlir::Operation::operand_range KernelsOp::getWaitValues() {
Expand Down Expand Up @@ -1646,33 +1640,21 @@ Value LoopOp::getDataOperand(unsigned i) {
bool LoopOp::hasAuto() { return hasAuto(mlir::acc::DeviceType::None); }

bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) {
if (auto arrayAttr = getAuto_()) {
if (findSegment(*arrayAttr, deviceType))
return true;
}
return false;
return hasDeviceType(getAuto_(), deviceType);
}

bool LoopOp::hasIndependent() {
return hasIndependent(mlir::acc::DeviceType::None);
}

bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) {
if (auto arrayAttr = getIndependent()) {
if (findSegment(*arrayAttr, deviceType))
return true;
}
return false;
return hasDeviceType(getIndependent(), deviceType);
}

bool LoopOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }

bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) {
if (auto arrayAttr = getSeq()) {
if (findSegment(*arrayAttr, deviceType))
return true;
}
return false;
return hasDeviceType(getSeq(), deviceType);
}

mlir::Value LoopOp::getVectorValue() {
Expand All @@ -1687,11 +1669,7 @@ mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) {
bool LoopOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }

bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) {
if (auto arrayAttr = getVector()) {
if (findSegment(*arrayAttr, deviceType))
return true;
}
return false;
return hasDeviceType(getVector(), deviceType);
}

mlir::Value LoopOp::getWorkerValue() {
Expand All @@ -1706,11 +1684,7 @@ mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) {
bool LoopOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }

bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) {
if (auto arrayAttr = getWorker()) {
if (findSegment(*arrayAttr, deviceType))
return true;
}
return false;
return hasDeviceType(getWorker(), deviceType);
}

mlir::Operation::operand_range LoopOp::getTileValues() {
Expand Down Expand Up @@ -1771,11 +1745,7 @@ mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType,
bool LoopOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }

bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) {
if (auto arrayAttr = getGang()) {
if (findSegment(*arrayAttr, deviceType))
return true;
}
return false;
return hasDeviceType(getGang(), deviceType);
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1815,11 +1785,7 @@ bool acc::DataOp::hasAsyncOnly() {
}

bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
if (auto arrayAttr = getAsyncOnly()) {
if (findSegment(*arrayAttr, deviceType))
return true;
}
return false;
return hasDeviceType(getAsyncOnly(), deviceType);
}

mlir::Value DataOp::getAsyncValue() {
Expand All @@ -1834,11 +1800,7 @@ mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
bool DataOp::hasWaitOnly() { return hasWaitOnly(mlir::acc::DeviceType::None); }

bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
if (auto arrayAttr = getWaitOnly()) {
if (findSegment(*arrayAttr, deviceType))
return true;
}
return false;
return hasDeviceType(getWaitOnly(), deviceType);
}

mlir::Operation::operand_range DataOp::getWaitValues() {
Expand Down Expand Up @@ -2091,20 +2053,6 @@ LogicalResult acc::DeclareOp::verify() {
// RoutineOp
//===----------------------------------------------------------------------===//

static bool hasDeviceType(std::optional<mlir::ArrayAttr> arrayAttr,
mlir::acc::DeviceType deviceType) {
if (!hasDeviceTypeValues(arrayAttr))
return false;

for (auto attr : *arrayAttr) {
auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
if (deviceTypeAttr.getValue() == deviceType)
return true;
}

return false;
}

static unsigned getParallelismForDeviceType(acc::RoutineOp op,
acc::DeviceType dtype) {
unsigned parallelism = 0;
Expand Down

0 comments on commit ee6199c

Please sign in to comment.