Add argument type and pass which can populate them #2049

Open · wants to merge 2 commits into main from lpanos/argument_types

Conversation

@LPanosTT (Contributor) commented Jan 30, 2025

Issue: #2032

  • Create a TTPopulateArgumentTypes pass which can be used to add argument types to the functions of a module.
  • Provide a TTIRToTTNNBackendPipeline option which can be used to populate the argument types of all public functions in the MLIR module. TTPopulateArgumentTypes is added to the TTIRToTTNNBackendPipeline.
    • Argument types are stored in the FuncOp's arg_attrs attribute.
  • If this option is unpopulated it defaults to an empty std::unordered_map, in which case the pass that populates the attributes returns immediately, doing nothing.
  • If any arg attributes are populated before running this pass, they will remain, except for existing tt.argument_type attributes: if this pass is run, existing tt.argument_type attributes will be overwritten by the ones specified in TTIRToTTNNBackendPipelineOptions.
    • This means that a user of this pipeline is free to add these attributes on their own if they wish. To avoid overwriting those values they can simply not set the argument map in TTIRToTTNNBackendPipelineOptions, as the default (empty map) causes the pass to not modify the module (see the sketch below).
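For illustration only (not part of the diff), a minimal sketch of the map contract described above; the type and enumerator names here are assumptions for the sketch, not the PR's actual declarations:

#include <string>
#include <unordered_map>
#include <vector>

// Assumed shapes for the sketch; the real declarations live in the PR.
enum class ArgumentType { Input, Parameter, Constant };
using ArgumentTypeMap =
    std::unordered_map<std::string, std::vector<ArgumentType>>;

static ArgumentTypeMap exampleArgumentTypes() {
  ArgumentTypeMap argumentTypes; // left empty => TTPopulateArgumentTypes is a no-op
  // One entry per public function; one ArgumentType per function argument.
  argumentTypes["forward"] = {ArgumentType::Input, ArgumentType::Parameter};
  return argumentTypes;
}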

llvm::errs() << "Public function: \"" << funcName
             << "\" not found in argument types map. Argument types "
                "must be populated for all public functions.\n";
signalPassFailure();
Contributor:

Will this fail compilation?

Contributor Author:

Yes. I'm still not sure how we want to handle cases where values are not provided for every argument, or whether we should handle them at all, so for now I decided to signalPassFailure.

Contributor Author:

Oh, and for the specific failure this comment is on: I did this because we do want to provide these attributes for public functions, since any entry points will be public.

@nsmithtt (Contributor) commented Jan 31, 2025

I don't think it should be a hard error if you supply argument types for only some of the public functions. But I agree it's a hard error if you did supply argument types for a function and len(argumentTypes) == numFunctionArgs doesn't hold; I think you handle this below.
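A rough sketch of the arity check being referred to (variable names follow the snippets later in this thread; this is illustrative, not the actual diff):

auto it = map.find(funcName);
if (it != map.end() && it->second.size() != func.getNumArguments()) {
  // The user supplied types for this function, but the count doesn't match.
  llvm::errs() << "Function: \"" << funcName << "\" has "
               << func.getNumArguments() << " arguments, but "
               << it->second.size() << " argument types were provided.\n";
  signalPassFailure();
}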

Contributor:

Yeah, this seems too invasive for something that would be an optional optimization opportunity. I agree that the wrong number of arguments or non-existent functions should be hard errors, because they suggest the user wanted to explicitly mark the types of arguments but made an error in the process. If this is going to be a hard error we probably also need a 4th ArgumentType, UNKNOWN (maybe we need it regardless: what happens if we don't know the type in advance for some arguments but know it for others, or if an argument can have more than one type depending on the context?).

Contributor Author:

I've removed this cause for failure.

bool parse(llvm::cl::Option &O, StringRef argName,
           const StringRef &commandLineArg, ArgumentTypeMap &val) {

  std::string arg = commandLineArg.str();
Contributor:

Why did you make a std::string from a StringRef? StringRef actually has some very convenient member functions for parsing that are much easier to use than a rudimentary find on std::string.
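For context, a tiny illustration (not from this PR) of the kind of StringRef helpers meant here:

#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/raw_ostream.h"

// Split a made-up "name=a, b" fragment without copying into std::string.
static void splitExample(llvm::StringRef entry) {
  auto [name, rest] = entry.split('=');   // "fn=a, b" -> ("fn", "a, b")
  llvm::SmallVector<llvm::StringRef> parts;
  rest.split(parts, ',');                 // "a, b" -> ["a", " b"]
  for (llvm::StringRef part : parts)
    llvm::errs() << name << " -> " << part.trim() << "\n"; // trim() strips whitespace
}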

Contributor Author:

I had to head out a little early that day, but I wanted to get something working that people could test, so I did this part a little fast. I'll be making this command-line parsing cleaner in the future.

: llvm::cl::parser<ArgumentTypeMap>(O) {}

// parse - Return true on error.
bool parse(llvm::cl::Option &O, StringRef argName,
Contributor:

You have a very simple CFG with only a handful of nonterminal symbols; it wouldn't take much more effort to make a proper recursive descent parser than this one. It would make evolving this option much easier, and since it's still a POC there is a high chance that it will evolve over time (it would also be easier to drop the artificial no-whitespace constraint, since whitespace would not be used as a separator).

Contributor:

If I can suggest an alternative to this, it would be a key/value format like this:

--ttir-to-ttnn-backend-pipeline="forward1=input,parameter; forward2=input,parameter"

The parser for this should be very simple:

bool parse(llvm::cl::Option &O, llvm::StringRef argName,
           const llvm::StringRef &commandLineArg, ArgumentTypeMap &val) {
  llvm::StringRef errorMessage =
      "Invalid format. Expected: function=arg1,arg2; function=arg1,arg2";
  llvm::StringRef arg = commandLineArg;

  llvm::SmallVector<llvm::StringRef, 8> entries;
  arg.split(entries, ';'); // Split functions by `;`

  for (llvm::StringRef entry : entries) {
    entry = entry.trim(); // Tolerate whitespace around each entry.
    if (entry.empty()) {
      continue; // Allow a trailing `;`.
    }

    size_t equalPos = entry.find('=');
    if (equalPos == llvm::StringRef::npos) {
      llvm::errs() << errorMessage << "\n";
      return true;
    }

    llvm::StringRef funcName = entry.take_front(equalPos);
    llvm::StringRef argsStr = entry.drop_front(equalPos + 1);

    llvm::SmallVector<llvm::StringRef, 8> argNames;
    argsStr.split(argNames, ','); // Split arguments by `,`

    llvm::SmallVector<ArgumentType, 8> argTypes;
    for (llvm::StringRef typeStr : argNames) {
      auto argTypeEnum = ArgumentTypeStringToEnum(typeStr.trim());
      if (!argTypeEnum.has_value()) {
        llvm::errs() << "Invalid argument type: " << typeStr << "\n";
        return true;
      }
      argTypes.push_back(argTypeEnum.value());
    }

    val[funcName.str()] = {argTypes.begin(), argTypes.end()};
  }

  return false;
}

@LPanosTT (Contributor, author) commented Feb 3, 2025

@mtopalovicTT this works for me, although I don't think we can get around the whitespace issue. The opt parser seems to delimit the options themselves with spaces, so individual option values cannot contain spaces.

What I mean is: let's say we want to run

--ttir-to-ttnn-backend-pipeline

with a few options set. The options must be delimited with a space:

--ttir-to-ttnn-backend-pipeline="enable-this=true enable-that=false"

So for argument-types=..., the value can't contain any spaces. We may be able to work around this with surrounding quotation marks ("..."); I will look into it.

llvm::errs() << "Public fucunction: \"" << funcName
<< "\" not found in argument types map. Argument types "
"must be populated for all public functions.\n";
signalPassFailure();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, this seems too invasive for something that would be an optional optimization opportunity. I agree that the wrong number of arguments/non-existing functions should be hard errors because they suggest the user wanted to explicitly mark types of arguments but made an error in the process. If this is going to be a hard error we probably also need a 4th ArgumentType UNKNOWN (maybe we need it regardless of that, what happens if we don't know in advance for some arguments, but know for the others, or if it can have more than one type depending on the context).

@vroubtsovTT (Contributor) commented Feb 2, 2025

I might be a bit late here, but why not use a facility that already exists in MLIR: operation arg attributes?

An example from the func dialect doc:

// A function with an argument attribute
func.func private @example_fn_arg(%x: i32 {swift.self = unit})

and we could have something along the lines of

func.func @main(%arg0: tensor<1x128x128x384xf32> {tt.input_type = "input"},
    %arg1: tensor<384xf32>  {tt.input_type = "parameter"},  ...) -> tensor<1x132x132x384xf32> {
    ...
}

It seems like these are already managed as an array of attr dictionaries per func::FuncOp and there are build() method overloads that accept them. At the ODS level FuncOp defines these as

  let arguments = (ins ... OptionalAttr<DictArrayAttr>:$arg_attrs, ...);
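For completeness, a minimal sketch of setting such an attribute from C++ (this assumes MLIR's FunctionOpInterface helpers on func::FuncOp; the attribute name follows this PR, the helper itself is illustrative):

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinAttributes.h"

// Attach (or overwrite) a single named attribute on argument `argIdx`,
// leaving any other entries in that argument's attr dictionary intact.
static void setArgumentTypeAttr(mlir::func::FuncOp func, unsigned argIdx,
                                mlir::Attribute argTypeAttr) {
  func.setArgAttr(argIdx,
                  mlir::StringAttr::get(func.getContext(), "tt.argument_type"),
                  argTypeAttr);
}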

@LPanosTT (Contributor, author) commented Feb 2, 2025

@vroubtsovTT I'm not sure what you mean exactly. I am using the FuncOp arg attrs. The function signature examples you gave are what the output MLIR module would look like. You can also pass one in like this and not run the pass.

@vroubtsovTT (Contributor):

> @vroubtsovTT I'm not sure what you mean exactly. I am using the FuncOp arg attrs. The function signature examples you gave are what the output MLIR module would look like. You can also pass one in like this and not run the pass.

I got confused by the team discussion last week, which showed a new FuncOp attribute that was an array of arg input types parallel to the operand array. I now see func.getArgAttrDict() in your changes, so we are all good.

@sdjordjevicTT (Contributor) commented Feb 3, 2025

Adding more detail to the PR description would be very helpful, even if the PR is still in draft form. It provides insight into what you are trying to achieve from the end-to-end user perspective.

Upon reviewing the code, I believe we may be overcomplicating things unnecessarily. Instead of implementing a pipeline option that needs to be parsed and managed separately, why not use these defined enums and type declarations, allowing users to provide them as part of argument attributes during IR creation? This approach would establish a much simpler end-to-end contract with the user.

To sum up everything said, I see two approaches here:

  1. The user creates a func arg attribute on their own and embeds that information as part of TTIR creation:
@func.func(arg0: tensor<1xf32> {ttir.name = "input_1", tt.argument_type = #tt.argument_type<input>}, arg1: ...)
{}
  • Pros: Easier to implement, clearer contract between user and compiler
  • Cons: IR needs to be modified if you want to change the argument type, but my assumption is that would be a rare case
  2. The user creates IR without func arg attributes and specifies the pipeline option at compile time, as you suggested:
@func.func(arg0: tensor<1xf32>, arg1: tensor<1x32>, ...)
{}
ttmlir-opt --ttir-to-ttnn-backend-pipeline="argument-types=..." ...
  • Pros: More flexible, doesn't require changes in the IR if we want to change an argument type
  • Cons: More complex, additional pass to populate args, loosely coupled contract with the user (not clear when to fail to compile, when to emit an error, etc.)

In my opinion, option 1 is sufficient for all our use cases, while option 2 seems overly complex. However, this is just my viewpoint, and I would appreciate hearing other perspectives as well.

@LPanosTT (Contributor, author) commented Feb 3, 2025

@sdjordjevicTT So you can populate the argument types during IR creation right now if you wish. If you run the pass it will overwrite the argument types you've already populated. If you pass an empty map it will not run the pass at all (it will return immediately). In tt-torch, we rely on torch-mlir to generate the initial IR in stablehlo. So we cannot populate with our custom attributes unless we are willing to write our own MLIR code in the frontend. Thus the reason for option 2.

I should make it more clear that you can provide the argument types without running the pass though.

@sdjordjevicTT (Contributor):

> @sdjordjevicTT So you can populate the argument types during IR creation right now if you wish. If you run the pass it will overwrite the argument types you've already populated. If you pass an empty map it will not run the pass at all (it will return immediately). In tt-torch, we rely on torch-mlir to generate the initial IR in stablehlo. So we cannot populate with our custom attributes unless we are willing to write our own MLIR code in the frontend. Thus the reason for option 2.
>
> I should make it more clear that you can provide the argument types without running the pass though.

This is exactly the context I was asking to be part of the PR description. We are not familiar with the tt-torch frontend, hence we can't provide quality insights. If I understand correctly now, from tt-torch you just invoke the conversion from torch-mlir to stablehlo MLIR and then invoke the tt-mlir compiler on the obtained stablehlo MLIR?

How complex will it be to add MLIR code to the tt-torch frontend? Just to be clear, I am not against your implemented approach here, let's just put all the options on the table, see the pros/cons, and pick the one that is best overall.

Also, the issue that you are referring to in the PR description is a tt-metal issue; how is it relevant to this PR? Am I missing something?

@LPanosTT (Contributor, author) commented Feb 3, 2025

@sdjordjevicTT That is the wrong issue 🤦, my mistake, it's for another PR I'm working on. I've fixed the description to include the proper issue.

> If I understand correctly now, from tt-torch you just invoke the conversion from torch-mlir to stablehlo MLIR and then invoke the tt-mlir compiler on the obtained stablehlo MLIR?

That is correct: we run the StablehloToTTIRPipeline first, then the TTIRToTTNNBackendPipeline. We get the initial stablehlo module via the third-party torch-mlir, which converts the initial torch.fx graph into MLIR.

> How complex will it be to add MLIR code to the tt-torch frontend? Just to be clear, I am not against your implemented approach here, let's just put all the options on the table, see the pros/cons, and pick the one that is best overall.

It would be about as complex as implementing this pass (minus the custom command-line parser), although I think that writing any code that modifies a TTIR module in a frontend would be undesirable. The argument types will be necessary for optimizations that will be built into the core compiler. So if a frontend is not building the initial MLIR module itself (like tt-torch/tt-xla), we should provide a way to populate the argument types that does not put the burden of writing that code on the frontend.

@LPanosTT force-pushed the lpanos/argument_types branch 4 times, most recently from 44b70c6 to 73780a9, on February 5, 2025 15:33
@LPanosTT marked this pull request as ready for review on February 5, 2025 15:33
@LPanosTT force-pushed the lpanos/argument_types branch from 73780a9 to b0cec28 on February 5, 2025 15:37
@LPanosTT force-pushed the lpanos/argument_types branch 2 times, most recently from 26f08c2 to c7f4fed, on February 5, 2025 20:02
@LPanosTT force-pushed the lpanos/argument_types branch from c7f4fed to dde1b78 on February 5, 2025 22:52
llvm::StringRef argsStr = entry.drop_front(equalPos + 1);

llvm::SmallVector<llvm::StringRef> argNames;
argsStr.split(argNames, ','); // Split arguments by `,`
Contributor:

Check that argNames is > 0

static constexpr StringRef argumentTypes = "argument-types";
};

struct TTArgumentTypeVector {
Contributor:

Why is this wrapper needed?

Contributor Author:

I was keeping it in line with how some of the TTNN passes that are defined purely in C++ handle the types they hold in an llvm::StringMap. But it could be removed.

void runOnOperation() final {
  // Currently, we allow compiling without assigning argument types.
  auto map = argumentTypeMap.getValue();
  if (map.size() == 0) {
Contributor:

nit: map.empty

}

mlir::ModuleOp module = getOperation();
std::vector<std::string> funcNames;
Contributor:

nit: use llvm equivalent for vector and string

auto funcName = func.getName().str();
funcNames.push_back(funcName);

if (map.find(funcName) != map.end()) {
Contributor:

You can check if funcName exists in the map and continue if it doesn't, to avoid nesting:

if (map.find(funcName) == map.end()) {
  continue;
}

for (uint32_t i = 0; i < func.getNumArguments(); i++) {
  // The current argument may already have attributes, so we need to add
  // the argument type to that DictionaryAttr rather than overwrite it.
  std::optional<mlir::DictionaryAttr> currentArgAttrDict =
Contributor:

You can check whether the arg attr dict is present directly if you don't need the optional. Something like below (although I'm not sure it will compile...)

llvm::SmallVector<mlir::NamedAttribute, 4> newArgAttrs;
if (auto existingAttrs = func.getArgAttrDict(i)) {
    for (mlir::NamedAttribute attr : existingAttrs) {
        // Overwrite existing "tt.argument_type" while keeping others.
        if (attr.getName() == "tt.argument_type") {
            llvm::errs() << "WARNING: Overwriting existing argument type "
                            "attribute for function \"" << funcName
                         << "\" argument " << i << "\n";
            continue;
        }
        newArgAttrs.push_back(attr);
    }
}
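A possible continuation of the snippet above (illustrative only; argTypeAttr is a stand-in for whatever attribute the pass builds from the user-supplied map):

// Append the new (hypothetical) entry and write the rebuilt dictionary
// back onto argument i.
newArgAttrs.emplace_back(
    mlir::StringAttr::get(func.getContext(), "tt.argument_type"), argTypeAttr);
func.setArgAttrs(i, newArgAttrs);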

}

mlir::ModuleOp module = getOperation();
std::vector<std::string> funcNames;
Contributor:

Use llvm::SmallSet for this. You will be able to use contains on it instead of iterating through the whole vector.

}

mlir::ModuleOp module = getOperation();
std::vector<std::string> funcNames;
Contributor:

Use llvm::SmallSet instead of std::vector. You can later use contains to check whether a function name from argTypeMap exists in the module.
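Roughly what that could look like inside runOnOperation (a sketch, assuming the surrounding pass context seen in the other snippets):

llvm::SmallSet<llvm::StringRef, 8> funcNames; // from "llvm/ADT/SmallSet.h"
module.walk([&](mlir::func::FuncOp func) { funcNames.insert(func.getName()); });

for (const auto &kv : map) {
  if (!funcNames.contains(kv.first())) {
    // A function named in argument-types is missing from the module.
    signalPassFailure();
  }
}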

llvm::errs() << "Function: \"" << kv.first()
<< "\" was provided in the argument types map, however it "
"was not found in module!\n";
signalPassFailure();
Contributor:

Do we want to error out here or just warn?

Contributor Author:

I think if the user provides a function name it means they believe it exists. So if it doesn't, that would mean they made a mistake.

@LPanosTT force-pushed the lpanos/argument_types branch from dde1b78 to f940a0d on February 7, 2025 13:49
This can be used to distinguish inputs, parameters, and constants. Added
a pass, "TTPopulateArgumentTypes", to add these attributes to the
function arguments according to an llvm::StringMap passed by the user or
frontend. The pass can be opted out of by passing an empty map.
@LPanosTT force-pushed the lpanos/argument_types branch from f940a0d to 1817605 on February 7, 2025 23:01