-
Notifications
You must be signed in to change notification settings - Fork 15
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add argument type and pass which can populate them #2049
base: main
Are you sure you want to change the base?
Conversation
llvm::errs() << "Public fucunction: \"" << funcName | ||
<< "\" not found in argument types map. Argument types " | ||
"must be populated for all public functions.\n"; | ||
signalPassFailure(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will this fail compilation?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, still not sure how we want to handle cases when values are not provided for every argument, or if we should at all. So for now I decided to signalPassFailure.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh and for the specific failure this comment is on, I did this since we do want to provide these attributes for public functions since any entrypoints will be public
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think it should be a hard error if you supply argument types for only some of public functions. But I agree it's a hard error if you did supply argument types and this doesn't hold len(argumentType) == numFunctionArgs
, but I think you handle this below.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, this seems too invasive for something that would be an optional optimization opportunity. I agree that the wrong number of arguments/non-existing functions should be hard errors because they suggest the user wanted to explicitly mark types of arguments but made an error in the process. If this is going to be a hard error we probably also need a 4th ArgumentType
UNKNOWN
(maybe we need it regardless of that, what happens if we don't know in advance for some arguments, but know for the others, or if it can have more than one type depending on the context).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've removed this cause for failure.
bool parse(llvm::cl::Option &O, StringRef argName, | ||
const StringRef &commandLineArg, ArgumentTypeMap &val) { | ||
|
||
std::string arg = commandLineArg.str(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why did you make a std::string
from StringRef
. StringRef
actually has some very convenient member functions for parsing, that are much easier to use than rudimentary find
on std::string
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I had to head out that day a little early, but I wanted to get something working people could test out so I did this part a little fast. I'll be making this command line parsing cleaner in the future.
: llvm::cl::parser<ArgumentTypeMap>(O) {} | ||
|
||
// parse - Return true on error. | ||
bool parse(llvm::cl::Option &O, StringRef argName, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You have a very simple CFG with only a handful of nonterminal symbols, it wouldn't take a much more effort to make a proper recursive descent parser than this one. It would make an evolution of this option much easier, and since it's still POC there is a high chance that it will evolve over time (also it would be easier to remove artificial constraints of non-existing whitespace, since white-space is not used as a separator).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If I can suggest alternative to this it would be key/value format like this:
--ttir-to-ttnn-backend-pipeline="forward1=input,parameter; forward2=input,parameter"
parser for this should be very simple:
bool parse(llvm::cl::Option &O, llvm::StringRef argName,
const llvm::StringRef &commandLineArg, ArgumentTypeMap &val) {
llvm::StringRef errorMessage = "Invalid format. Expected: function=arg1,arg2; function=arg1,arg2";
llvm::StringRef arg = commandLineArg;
llvm::SmallVector<llvm::StringRef, 8> entries;
arg.split(entries, ';'); // Split functions by `;`
for (llvm::StringRef entry : entries) {
size_t equalPos = entry.find('=');
if (equalPos == llvm::StringRef::npos) {
llvm::errs() << errorMessage << "\n";
return true;
}
llvm::StringRef funcName = entry.take_front(equalPos);
llvm::StringRef argsStr = entry.drop_front(equalPos + 1);
llvm::SmallVector<llvm::StringRef, 8> argNames;
argsStr.split(argNames, ','); // Split arguments by `,`
llvm::SmallVector<ArgumentType, 8> argTypes;
for (llvm::StringRef arg : argNames) {
auto argTypeEnum = ArgumentTypeStringToEnum(arg);
if (!argTypeEnum.has_value()) {
llvm::errs() << "Invalid argument type: " << arg << "\n";
return true;
}
argTypes.push_back(argTypeEnum.value());
}
val[funcName.str()] = {argTypes.begin(), argTypes.end()};
}
return false;
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@mtopalovicTT this works for me. Although I don't think we can get around the whitespace issue. The opt parser seems to delimit different options themselves with a space, so for individual option values I believe we cannot use spaces.
What I mean by this is that lets say we want to run
--ttir-to-ttnn-backend-pipeline
with a few options set. The options must be delimited with a space:
--ttir-to-ttnn-backend-pipeline="enable-this=true enable-that=false"
So for argument-types=...
the ...
can't contain any spaces. We may be able to stop this by using surrounding quotations ("..."
). I will look into this.
llvm::errs() << "Public fucunction: \"" << funcName | ||
<< "\" not found in argument types map. Argument types " | ||
"must be populated for all public functions.\n"; | ||
signalPassFailure(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, this seems too invasive for something that would be an optional optimization opportunity. I agree that the wrong number of arguments/non-existing functions should be hard errors because they suggest the user wanted to explicitly mark types of arguments but made an error in the process. If this is going to be a hard error we probably also need a 4th ArgumentType
UNKNOWN
(maybe we need it regardless of that, what happens if we don't know in advance for some arguments, but know for the others, or if it can have more than one type depending on the context).
I might be a bit late here, but why not use a facility that already exist in MLIR: operation arg attributes. An example from the func dialect doc:
and we could have something along the lines of
It seems like these are already managed as an array of attr dictionaries per
|
@vroubtsovTT Im not sure what you mean exactly. I am using the FuncOp arg attrs. The function signature examples you gave are what the output MLIR module would look like. You can also pass one like this and not run the pass. |
I got confused by the team discussion last week, which showed a new FuncOp attribute that was an array of arg input types parallel to the operand array. I now see |
Adding more detail to the PR description would be very helpful, even if the PR is still in draft form. It provides insight into what you are trying to achieve from the end-to-end user perspective. Upon reviewing the code, I believe we may be overcomplicating things unnecessarily. Instead of implementing a pipeline option that needs to be parsed and managed separately, why not use these defined enums and type declarations, allowing users to provide them as part of argument attributes during IR creation? This approach would establish a much simpler end-to-end contract with the user. To sum up, everything said, I see two approaches here:
@func.func(arg0: tensor<1xf32> {ttir.name = "input_1", tt.argument_type = #tt.argument_type<input>}, arg1: ...)
{}
@func.func(arg0: tensor<1xf32>, arg1: tensor<1x32>, ...)
{} ttmlir-opt --ttir-to-ttnn-backend-pipeline="argument-types=..." ...
In my opinion, option 1 is sufficient for all our use cases, while option 2 seems overly complex. However, this is just my viewpoint, and I would appreciate hearing other perspectives as well. |
@sdjordjevicTT So you can populate the argument types during IR creation right now if you wish. If you run the pass it will overwrite the argument types you've already populated. If you pass an empty map it will not run the pass at all (it will return immediately). In tt-torch, we rely on torch-mlir to generate the initial IR in stablehlo. So we cannot populate with our custom attributes unless we are willing to write our own MLIR code in the frontend. Thus the reason for option 2. I should make it more clear that you can provide the argument types without running the pass though. |
This is exactly contextually what I was asking for to be part of the PR description. We are not familiar with the tt-torch frontend, hence we can't provide quality insights. If I understand correctly now, from tt-torch, you just invoke the conversion from torch-mlir to stablehlo-mlir and invoke the tt-mlir compiler with obtained stablehlo mlir? How complex will it be to add an MLIR code to the tt-torch frontend? Just to be clear, I am not against your implemented approach here, let's just put all the options on the table, see the pros\cons, and pick the one that is the best overall. Also, the issue that you are referring to in the PR description is a tt-metal issue, how it is relevant to this PR, am I missing something? |
@sdjordjevicTT That is the wrong issue 🤦, my mistake, it's for another PR I'm working on. I've fixed the description to include the proper issue.
That is correct, we run the StablehloToTTIRPipeline first, then TTIRToTTNNBackendPipeline. We get the initial stablehlo module via the third-party: torch-mlir. torch-mlir is what converts the initial torch.fx graph into mlir.
It would be about as complex as implementing this pass (minus the custom command-line parser). Although I think that writing any code that would modify a TTIR module in a frontend would be undesirable. The argument types will be necessary for optimizations which will be built in to the core compiler. So, if a frontend is not building the initial MLIR module itself (like tt-torch/tt-xla), we should provide a way to populate the argument types that does not put the burden of writing the code that does so in their frontend. |
: llvm::cl::parser<ArgumentTypeMap>(O) {} | ||
|
||
// parse - Return true on error. | ||
bool parse(llvm::cl::Option &O, StringRef argName, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If I can suggest alternative to this it would be key/value format like this:
--ttir-to-ttnn-backend-pipeline="forward1=input,parameter; forward2=input,parameter"
parser for this should be very simple:
bool parse(llvm::cl::Option &O, llvm::StringRef argName,
const llvm::StringRef &commandLineArg, ArgumentTypeMap &val) {
llvm::StringRef errorMessage = "Invalid format. Expected: function=arg1,arg2; function=arg1,arg2";
llvm::StringRef arg = commandLineArg;
llvm::SmallVector<llvm::StringRef, 8> entries;
arg.split(entries, ';'); // Split functions by `;`
for (llvm::StringRef entry : entries) {
size_t equalPos = entry.find('=');
if (equalPos == llvm::StringRef::npos) {
llvm::errs() << errorMessage << "\n";
return true;
}
llvm::StringRef funcName = entry.take_front(equalPos);
llvm::StringRef argsStr = entry.drop_front(equalPos + 1);
llvm::SmallVector<llvm::StringRef, 8> argNames;
argsStr.split(argNames, ','); // Split arguments by `,`
llvm::SmallVector<ArgumentType, 8> argTypes;
for (llvm::StringRef arg : argNames) {
auto argTypeEnum = ArgumentTypeStringToEnum(arg);
if (!argTypeEnum.has_value()) {
llvm::errs() << "Invalid argument type: " << arg << "\n";
return true;
}
argTypes.push_back(argTypeEnum.value());
}
val[funcName.str()] = {argTypes.begin(), argTypes.end()};
}
return false;
}
44b70c6
to
73780a9
Compare
73780a9
to
b0cec28
Compare
26f08c2
to
c7f4fed
Compare
c7f4fed
to
dde1b78
Compare
llvm::StringRef argsStr = entry.drop_front(equalPos + 1); | ||
|
||
llvm::SmallVector<llvm::StringRef> argNames; | ||
argsStr.split(argNames, ','); // Split arguments by `,` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Check that argNames is > 0
static constexpr StringRef argumentTypes = "argument-types"; | ||
}; | ||
|
||
struct TTArgumentTypeVector { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is this wrapper needed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was keeping it in line with how some of the TTNN passes which are defined in C++ alone handle the types they hold in llvm::StringMap
. But it could be removed.
void runOnOperation() final { | ||
// Currently, we will allow compile without assigning argument types. | ||
auto map = argumentTypeMap.getValue(); | ||
if (map.size() == 0) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: map.empty
} | ||
|
||
mlir::ModuleOp module = getOperation(); | ||
std::vector<std::string> funcNames; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: use llvm equivalent for vector and string
auto funcName = func.getName().str(); | ||
funcNames.push_back(funcName); | ||
|
||
if (map.find(funcName) != map.end()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you can check if funcName exists in map and continue if it doesn't to avoid nesting:
if (map.find(funcName) == map.end())
{
continue;
}
for (uint32_t i = 0; i < func.getNumArguments(); i++) { | ||
// The current argument may already have attributes, so we need to add | ||
// the argument type to that DictonaryAttr rather than overwrite it. | ||
std::optional<mlir::DictionaryAttr> currentArgAttrDict = |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can check if argDict is present in if
you don't need optional. Something like below (although I'm not sure if it will compile...)
llvm::SmallVector<mlir::NamedAttribute, 4> newArgAttrs;
if (auto existingAttrs = func.getArgAttrDict(i)) {
for (mlir::NamedAttribute attr : existingAttrs) {
// Overwrite existing "tt.argument_type" while keeping others.
if (attr.getName() == "tt.argument_type") {
llvm::errs() << "WARNING: Overwriting existing argument type "
"attribute for function \"" << funcName
<< "\" argument " << i << "\n";
continue;
}
newArgAttrs.push_back(attr);
}
}
} | ||
|
||
mlir::ModuleOp module = getOperation(); | ||
std::vector<std::string> funcNames; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use llvm::SmallSet for this. You will be able to use contains
on it instead of iterating through whole vector.
} | ||
|
||
mlir::ModuleOp module = getOperation(); | ||
std::vector<std::string> funcNames; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use llvm::SmallSet instead of std::vector. You can later use contains
to check if function name that was inside argTypeMap exists in module.
llvm::errs() << "Function: \"" << kv.first() | ||
<< "\" was provided in the argument types map, however it " | ||
"was not found in module!\n"; | ||
signalPassFailure(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we want to error out here or just warn?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think if the user provides a function name it means they think that it exists. So if it isn't that would mean they made a mistake.
dde1b78
to
f940a0d
Compare
This can be used to distinguish inputs, parameters, and constants. Added a pass: "TTPopulateArgumentTypes" to add these attributes to the function arguments according to an llvm:StringMap passed by the user or frontend. The pass can be opted out of if you pass an empty map.
f940a0d
to
1817605
Compare
Issue: #2032
TTPopulateArgumentTypes
pass which can be used to add argument types to the functions of a module.TTIRToTTNNBackedPipeline
option which can be used to populate the argument types of all public functions in the MLIR module.TTPopulateArgumentTypes
is added toTTIRToTTNNBackendPipeline
FuncOp
sarg_attrs
attribute.std::unordered_map
, in which case the pass that populates the attributes returns immediately - doing nothing.tt.argument_type
attributes. If this pass is run, the existingtt.argument_type
attributes will be overwritten by the ones specified inTTIRToTTNNBackendPipelineOptions
.TTIRToTTNNBackendPipelineOptions
as the default (empty map) causes the pass to not modify the module.