From 732a4126fb73cd4f873a7075cf133b4666877267 Mon Sep 17 00:00:00 2001 From: mingzheTerapines Date: Fri, 2 Aug 2024 14:28:52 +0800 Subject: [PATCH] Support union type --- lib/Dialect/Moore/MooreOps.cpp | 95 +++++++++++++++------------------- 1 file changed, 43 insertions(+), 52 deletions(-) diff --git a/lib/Dialect/Moore/MooreOps.cpp b/lib/Dialect/Moore/MooreOps.cpp index 5a638f012581..22ee8911d082 100644 --- a/lib/Dialect/Moore/MooreOps.cpp +++ b/lib/Dialect/Moore/MooreOps.cpp @@ -554,7 +554,11 @@ static std::optional getStructFieldIndex(Type type, StringAttr name) { return structType.getFieldIndex(name); if (auto structType = dyn_cast(type)) return structType.getFieldIndex(name); - assert(0 && "expected StructType or UnpackedStructType"); + if (auto unionType = dyn_cast(type)) + return unionType.getFieldIndex(name); + if (auto unionType = dyn_cast(type)) + return unionType.getFieldIndex(name); + assert(0 && "expected Struct-Like Type"); return {}; } @@ -563,6 +567,10 @@ static ArrayRef getStructMembers(Type type) { return structType.getMembers(); if (auto structType = dyn_cast(type)) return structType.getMembers(); + if (auto unionType = dyn_cast(type)) + return unionType.getMembers(); + if (auto unionType = dyn_cast(type)) + return unionType.getMembers(); return {}; } @@ -760,23 +768,18 @@ LogicalResult StructInjectOp::canonicalize(StructInjectOp op, //===----------------------------------------------------------------------===// LogicalResult UnionCreateOp::verify() { - /// checks if the types of the input is exactly equal to the union field + auto type = getStructFieldType(getType(), getFieldNameAttr()); + + /// checks if the type of the input is exactly equal to the union field /// type - return TypeSwitch(getType()) - .Case([this](auto &type) { - auto members = type.getMembers(); - auto resultType = getType(); - auto fieldName = getFieldName(); - for (const auto &member : members) - if (member.name == fieldName && member.type == resultType) - return success(); - emitOpError("input type must match the union field type"); - return failure(); - }) - .Default([this](auto &) { - emitOpError("input type must be UnionType or UnpackedUnionType"); - return failure(); - }); + + if (!type) + return emitOpError() << "union field " << getFieldNameAttr() + << " which does not exist in " << getInput().getType(); + if (type != getType()) + return emitOpError() << "result type " << getType() + << " must match union field type " << type; + return success(); } //===----------------------------------------------------------------------===// @@ -784,47 +787,35 @@ LogicalResult UnionCreateOp::verify() { //===----------------------------------------------------------------------===// LogicalResult UnionExtractOp::verify() { - /// checks if the types of the input is exactly equal to the one of the - /// types of the result union fields - return TypeSwitch(getInput().getType()) - .Case([this](auto &type) { - auto members = type.getMembers(); - auto fieldName = getFieldName(); - auto resultType = getType(); - for (const auto &member : members) - if (member.name == fieldName && member.type == resultType) - return success(); - emitOpError("result type must match the union field type"); - return failure(); - }) - .Default([this](auto &) { - emitOpError("input type must be UnionType or UnpackedUnionType"); - return failure(); - }); + auto type = getStructFieldType(getInput().getType(), getFieldNameAttr()); + + /// checks if the type of the input is exactly equal to the type of the result + /// union fields + if (!type) + return emitOpError() << "union field " << getFieldNameAttr() + << " which does not exist in " << getInput().getType(); + if (type != getType()) + return emitOpError() << "result type " << getType() + << " must match union field type " << type; + return success(); } //===----------------------------------------------------------------------===// -// UnionExtractOp +// UnionExtractRefOp //===----------------------------------------------------------------------===// LogicalResult UnionExtractRefOp::verify() { - /// checks if the types of the result is exactly equal to the type of the - /// refe union field - return TypeSwitch(getInput().getType().getNestedType()) - .Case([this](auto &type) { - auto members = type.getMembers(); - auto fieldName = getFieldName(); - auto resultType = getType().getNestedType(); - for (const auto &member : members) - if (member.name == fieldName && member.type == resultType) - return success(); - emitOpError("result type must match the union field type"); - return failure(); - }) - .Default([this](auto &) { - emitOpError("input type must be UnionType or UnpackedUnionType"); - return failure(); - }); + auto type = getStructFieldType(getInput().getType().getNestedType(), + getFieldNameAttr()); + /// checks if the type of the result is exactly equal to the type of the + /// referring union field + if (!type) + return emitOpError() << "union field " << getFieldNameAttr() + << " which does not exist in " << getInput().getType(); + if (type != getType()) + return emitOpError() << "result type " << getType() + << " must match union field type " << type; + return success(); } //===----------------------------------------------------------------------===//