Skip to content

Commit

Permalink
Applied Milans suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
LPanosTT committed Feb 10, 2025
1 parent 8f684f7 commit fc2ce42
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 48 deletions.
5 changes: 5 additions & 0 deletions lib/Dialect/TT/Utils/PassOverrides.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ bool ArgumentTypeMapParser::parse(llvm::cl::Option &O, llvm::StringRef argName,

llvm::SmallVector<llvm::StringRef> argNames;
argsStr.split(argNames, ','); // Split arguments by `,`
if (argNames.empty()) {
llvm::errs() << "Provided empty argument list for funtion name: \""
<< funcName << "\"" << "\n";
return true;
}

llvm::SmallVector<ArgumentType> argTypes;
for (llvm::StringRef arg : argNames) {
Expand Down
92 changes: 44 additions & 48 deletions lib/Dialect/TT/Utils/PopulateArgumentTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include <llvm/ADT/SmallSet.h>

namespace mlir::tt {

Expand Down Expand Up @@ -136,72 +137,67 @@ class TTPopulateArgumentTypes
void runOnOperation() final {
// Currently, we will allow compile without assigning argument types.
auto map = argumentTypeMap.getValue();
if (map.size() == 0) {
if (map.empty()) {
llvm::errs()
<< "WARNING: Empty argument type map provided. Skipping argument "
"type population. This may affect subsequent compile steps.\n";
return;
}

mlir::ModuleOp module = getOperation();
std::vector<std::string> funcNames;
llvm::SmallSet<StringRef, 8> funcNames;

// Iterate through every function as we may be assigning argument types to
// them.
for (auto func : module.getOps<mlir::func::FuncOp>()) {
auto funcName = func.getName().str();
funcNames.push_back(funcName);

if (map.find(funcName) != map.end()) {
std::vector<mlir::Attribute> argTypeAttrs;
for (auto argType : map.at(funcName).argumentTypes) {
argTypeAttrs.push_back(
mlir::tt::ArgumentTypeAttr::get(&getContext(), argType));
}
if (func.getNumArguments() != argTypeAttrs.size()) {
llvm::errs() << "Function: \"" << funcName
<< "\" argument count mismatch.\n";
signalPassFailure();
}
// Need to update/create the DictionaryAttr for each corresponding
// function argument.
for (uint32_t i = 0; i < func.getNumArguments(); i++) {
// The current argument may already have attributes, so we need to add
// the argument type to that DictonaryAttr rather than overwrite it.
std::optional<mlir::DictionaryAttr> currentArgAttrDict =
func.getArgAttrDict(i)
? std::make_optional(mlir::cast<mlir::DictionaryAttr>(
func.getArgAttrDict(i)))
: std::nullopt;
std::vector<mlir::NamedAttribute> newArgAttrs;
if (currentArgAttrDict.has_value()) {
for (mlir::NamedAttribute currentArgAttr :
currentArgAttrDict.value()) {
// If this argument already has an argumnet type, this pass wil
// overwrite it. Log a warning.
if (currentArgAttr.getName() != "tt.argument_type") {
newArgAttrs.push_back(currentArgAttr);
} else {
llvm::errs() << "WARNING: Overwriting existing argument type "
"attribute for function: \""
<< funcName << "\" argument: " << i << "\n";
}
auto funcName = func.getName();
funcNames.insert(funcName);

if (map.find(funcName) == map.end()) {
continue;
}

std::vector<mlir::Attribute> argTypeAttrs;
for (auto argType : map.at(funcName).argumentTypes) {
argTypeAttrs.push_back(
mlir::tt::ArgumentTypeAttr::get(&getContext(), argType));
}
if (func.getNumArguments() != argTypeAttrs.size()) {
llvm::errs() << "Function: \"" << funcName
<< "\" argument count mismatch.\n";
signalPassFailure();
}
// Need to update/create the DictionaryAttr for each corresponding
// function argument.
for (uint32_t i = 0; i < func.getNumArguments(); i++) {
// The current argument may already have attributes, so we need to add
// the argument type to that DictonaryAttr rather than overwrite it.
std::vector<mlir::NamedAttribute> newArgAttrs;
if (auto currentArgAttrDict = func.getArgAttrDict(i)) {
for (mlir::NamedAttribute currentArgAttr : currentArgAttrDict) {
// If this argument already has an argumnet type, this pass wil
// overwrite it. Log a warning.
if (currentArgAttr.getName() != "tt.argument_type") {
newArgAttrs.push_back(currentArgAttr);
} else {
llvm::errs() << "WARNING: Overwriting existing argument type "
"attribute for function: \""
<< funcName << "\" argument: " << i << "\n";
}
}
mlir::NamedAttribute attr(
mlir::StringAttr::get(&getContext(), "tt.argument_type"),
argTypeAttrs[i]);
newArgAttrs.push_back(attr);

func.setArgAttrs(
i, mlir::DictionaryAttr::get(&getContext(), newArgAttrs));
}
mlir::NamedAttribute attr(
mlir::StringAttr::get(&getContext(), "tt.argument_type"),
argTypeAttrs[i]);
newArgAttrs.push_back(attr);

func.setArgAttrs(i,
mlir::DictionaryAttr::get(&getContext(), newArgAttrs));
}
}

for (auto &kv : map) {
if (std::find(funcNames.begin(), funcNames.end(), kv.first()) ==
funcNames.end()) {
if (!funcNames.contains(kv.first())) {
llvm::errs() << "Function: \"" << kv.first()
<< "\" was provided in the argument types map, however it "
"was not found in module!\n";
Expand Down

0 comments on commit fc2ce42

Please sign in to comment.