Skip to content

Commit

Permalink
Add Attribute Name and Type Mapping System (#1117)
Browse files Browse the repository at this point in the history
### Ticket
fixes: #1116 

### What's changed
This PR introduces a new attribute mapping system that provides a clean
way to handle attribute name remapping and type conversion when lowering
from Forge-FE to MLIR.

Key Changes:
- Added new AttributeMapper class that manages attribute name mappings
and type conversions
- Updated MLIRGenerator to use AttributeMapper for attribute handling
- Added default mappings for operations like repeat_interleave

The new system allows:
- Remapping attribute names (e.g., "kernel_size" -> "kernel_shape")
- Specifying target types for MLIR conversion (e.g., int -> uint32_t)
- Easy addition of new operation mappings
- Better separation of concerns between mapping and conversion

Example default mappings include:
- repeat_interleave: "repeats" as uint32
  • Loading branch information
mstojkovicTT authored Jan 28, 2025
1 parent 057277d commit eac7778
Showing 1 changed file with 92 additions and 14 deletions.
106 changes: 92 additions & 14 deletions forge/csrc/passes/lower_to_mlir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <cstdint>
#include <stdexcept>
#include <string>
#include <utils/assert.hpp>

// TTForge headers
#include "forge_graph_module.hpp"
Expand Down Expand Up @@ -50,6 +51,64 @@ using namespace tt;
* @brief Implementation of TT-MLIR emission from the Forge module (set of graphs).
*/

enum class TargetType
{
SourceType,
UInt32,
};

struct AttributeRemap
{
std::optional<std::string> new_name; // New name for the attribute
TargetType target_type_value; // Target type conversion

AttributeRemap(const std::optional<std::string> &name = std::nullopt, TargetType type = TargetType::SourceType) :
new_name(name), target_type_value(type)
{
}
};

class AttributeMapper
{
public:
AttributeMapper() { initialize_default_mappings(); }

// Add mapping for a specific op's attribute
void add_op_mapping(const std::string &op_name, const std::string &attr_name, const AttributeRemap &remap)
{
mappings_[op_name][attr_name] = remap;
}

// Get the mapped name and target type for an attribute
std::pair<std::string, TargetType> get_mapped_name_and_type(
const std::string &op_name, const std::string &attr_name) const
{
auto op_it = mappings_.find(op_name);
if (op_it != mappings_.end())
{
auto attr_it = op_it->second.find(attr_name);
if (attr_it != op_it->second.end())
{
const auto &remap = attr_it->second;
return {remap.new_name.value_or(attr_name), remap.target_type_value};
}
}
return {attr_name, TargetType::SourceType};
}

private:
// Mapping storage: op_name -> (attr_name -> remap)
std::map<std::string, std::map<std::string, AttributeRemap>> mappings_;

void initialize_default_mappings()
{
// repeat_interleave configuration
add_op_mapping("repeat_interleave", "repeats", AttributeRemap(std::nullopt, TargetType::UInt32));

// Add more default mappings here
}
};

class MLIRGenerator
{
public:
Expand Down Expand Up @@ -147,6 +206,9 @@ class MLIRGenerator
/// Map of lowering handlers for ttforge operations to MLIR.
std::map<std::string, HandlerType> lowering_handler_map;

/// Attribute mapper for handling attribute conversions
static AttributeMapper attr_mapper_;

/// Declares a variable in the current (only) scope.
/// The declaration corresponds to exactly one operation node in the TTForge graph.
void declare(graphlib::Node *node, mlir::Value value)
Expand All @@ -162,8 +224,21 @@ class MLIRGenerator
}

// Convert a TTForge attribute to an MLIR attribute.
mlir::Attribute convert_to_mlir_attribute(const tt::ForgeOpAttr &value)
mlir::Attribute convert_to_mlir_attribute(const tt::ForgeOpAttr &value, TargetType target_type)
{
if (target_type != TargetType::SourceType)
{
// Convert the attribute to the target type
switch (target_type)
{
case TargetType::UInt32:
TT_ASSERT(std::get<int>(value) >= 0, "Value must be an >= 0 for conversion to uint32");
return builder_.getUI32IntegerAttr(static_cast<uint32_t>(std::get<int>(value)));
default:
// If type not handled, throw an exception
throw std::runtime_error("Unhandled target type conversion");
}
}
return std::visit(
[this](auto &&arg) -> mlir::Attribute
{
Expand Down Expand Up @@ -357,8 +432,17 @@ class MLIRGenerator
// Evaluate operation operands: inputs and outputs per DPS
llvm::SmallVector<mlir::Value> operands = get_mlir_operands(graph, op_node);

// Evaluate opeartion attributes
llvm::SmallVector<mlir::NamedAttribute> attributes;
// Map forge to MLIR attributes for this operation.
llvm::SmallVector<mlir::NamedAttribute> mlir_attributes;
for (const auto &[name, value] : op_node->op_type().named_attrs)
{
auto [mapped_name, target_type] = attr_mapper_.get_mapped_name_and_type(op_node->op_name(), name);

mlir_attributes.push_back(
builder_.getNamedAttr(mapped_name, convert_to_mlir_attribute(value, target_type)));
}

// Handle operation segment sizes if needed
::llvm::ArrayRef<::llvm::StringRef> operation_attributes = TTIROp::getAttributeNames();
for (auto attribute_name : operation_attributes)
{
Expand All @@ -369,23 +453,15 @@ class MLIRGenerator
mlir::OpTrait::AttrSizedOperandSegments<void>::getOperandSegmentSizeAttr(),
builder_.getDenseI32ArrayAttr(
{static_cast<int32_t>(graph->operands(op_node).size()), static_cast<int32_t>(1)}));
attributes.push_back(operand_segment_sizes_attribute);
mlir_attributes.push_back(operand_segment_sizes_attribute);
}
}

for (const auto &attribute : op_node->op_type().named_attrs)
{
// convert atribute to mlir atribute
auto mlir_atribute = convert_to_mlir_attribute(attribute.second);
mlir::NamedAttribute named_attribute = builder_.getNamedAttr(attribute.first, mlir_atribute);
attributes.push_back(named_attribute);
}

auto op = builder_.create<TTIROp>(
get_tt_forge_operation_location(graph, op_node),
mlir::TypeRange(return_types),
mlir::ValueRange(operands),
attributes);
mlir_attributes);

return op.getOperation()->getResult(0);
}
Expand Down Expand Up @@ -532,9 +608,9 @@ class MLIRGenerator
lowering_handler_map["concatenate"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::ConcatOp>;
lowering_handler_map["conv2d"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::Conv2dOp>;
lowering_handler_map["cosine"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::CosOp>;
lowering_handler_map["embedding"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::EmbeddingOp>;
lowering_handler_map["embedding_bw"] =
&MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::EmbeddingBackwardOp>;
lowering_handler_map["embedding"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::EmbeddingOp>;
lowering_handler_map["equal"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::EqualOp>;
lowering_handler_map["exp"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::ExpOp>;
lowering_handler_map["gelu"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::GeluOp>;
Expand Down Expand Up @@ -567,6 +643,8 @@ class MLIRGenerator
lowering_handler_map["unsqueeze"] = &MLIRGenerator::emit_mlir_ttforge_op<mlir::tt::ttir::UnsqueezeOp>;
}
};

AttributeMapper MLIRGenerator::attr_mapper_;
} // namespace
namespace tt::passes
{
Expand Down

0 comments on commit eac7778

Please sign in to comment.