Skip to content

Commit

Permalink
Include comments with template argument names in Cpp code from EmitC (#…
Browse files Browse the repository at this point in the history
…403)

* Include comments with template arg names in Cpp code from EmitC

* Apply suggestions from code review

Co-authored-by: Corentin Ferry <corentin.ferry@amd.com>
Co-authored-by: Matthias Gehre <matthias.gehre@amd.com>

* Test for the presence of template arg names when there are no template args

---------

Co-authored-by: Corentin Ferry <corentin.ferry@amd.com>
Co-authored-by: Matthias Gehre <matthias.gehre@amd.com>
  • Loading branch information
3 people authored Nov 27, 2024
1 parent 20a6720 commit e25d207
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 4 deletions.
3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,7 @@ def EmitC_CallOpaqueOp : EmitC_Op<"call_opaque", [CExpression]> {
Arg<StrAttr, "the C++ function to call">:$callee,
Arg<OptionalAttr<ArrayAttr>, "the order of operands and further attributes">:$args,
Arg<OptionalAttr<ArrayAttr>, "template arguments">:$template_args,
Arg<OptionalAttr<StrArrayAttr>, "template argument names">:$template_arg_names,
Variadic<EmitCType>:$operands
);
let results = (outs Variadic<EmitCType>);
Expand All @@ -302,7 +303,7 @@ def EmitC_CallOpaqueOp : EmitC_Op<"call_opaque", [CExpression]> {
"::mlir::ValueRange":$operands,
CArg<"::mlir::ArrayAttr", "{}">:$args,
CArg<"::mlir::ArrayAttr", "{}">:$template_args), [{
build($_builder, $_state, resultTypes, callee, args, template_args,
build($_builder, $_state, resultTypes, callee, args, template_args, {},
operands);
}]
>
Expand Down
13 changes: 13 additions & 0 deletions mlir/lib/Dialect/EmitC/IR/EmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,19 @@ LogicalResult emitc::CallOpaqueOp::verify() {
}
}

if (std::optional<ArrayAttr> templateArgNames = getTemplateArgNames()) {
if (std::optional<ArrayAttr> templateArgsAttr = getTemplateArgs()) {
if ((*templateArgNames).size() &&
(*templateArgNames).size() != (*templateArgsAttr).size()) {
return emitOpError("number of template argument names must be equal to "
"number of template arguments");
}
} else {
return emitOpError("should not have names for template arguments if it "
"does not have template arguments");
}
}

if (llvm::any_of(getResultTypes(), llvm::IsaPred<ArrayType>)) {
return emitOpError() << "cannot return array type";
}
Expand Down
26 changes: 23 additions & 3 deletions mlir/lib/Target/Cpp/TranslateToCpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -659,11 +659,31 @@ static LogicalResult printOperation(CppEmitter &emitter,
return success();
};

auto emitNamedArgs =
[&](std::tuple<const Attribute &, const Attribute &> tuple)
-> LogicalResult {
Attribute attr = std::get<0>(tuple);
StringAttr argName = cast<StringAttr>(std::get<1>(tuple));

os << "/*" << argName.str() << "=*/";
return emitArgs(attr);
};

if (callOpaqueOp.getTemplateArgs()) {
os << "<";
if (failed(interleaveCommaWithError(*callOpaqueOp.getTemplateArgs(), os,
emitArgs)))
return failure();
if (callOpaqueOp.getTemplateArgNames() &&
!callOpaqueOp.getTemplateArgNames()->empty()) {
if (failed(interleaveCommaWithError(
llvm::zip_equal(*callOpaqueOp.getTemplateArgs(),
*callOpaqueOp.getTemplateArgNames()),
os, emitNamedArgs))) {
return failure();
}
} else {
if (failed(interleaveCommaWithError(*callOpaqueOp.getTemplateArgs(), os,
emitArgs)))
return failure();
}
os << ">";
}

Expand Down
24 changes: 24 additions & 0 deletions mlir/test/Dialect/EmitC/invalid_ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -524,3 +524,27 @@ func.func @test_verbatim(%arg0 : !emitc.ptr<i32>, %arg1 : i32) {
emitc.verbatim "{a} " args %arg0, %arg1 : !emitc.ptr<i32>, i32
return
}

// -----

func.func @template_args_with_names(%arg0: i32) {
// expected-error @+1 {{'emitc.call_opaque' op number of template argument names must be equal to number of template arguments}}
emitc.call_opaque "kernel1"(%arg0) {template_arg_names = ["N", "P"], template_args = [42 : i32]} : (i32) -> ()
return
}

// -----

func.func @template_args_with_names(%arg0: i32) {
// expected-error @+1 {{'emitc.call_opaque' op number of template argument names must be equal to number of template arguments}}
emitc.call_opaque "kernel1"(%arg0) {template_arg_names = ["N"], template_args = [42 : i32, 56 : i32]} : (i32) -> ()
return
}

// -----

func.func @template_args_with_names(%arg0: i32) {
// expected-error @+1 {{'emitc.call_opaque' op should not have names for template arguments if it does not have template arguments}}
emitc.call_opaque "kernel1"(%arg0) {template_arg_names = ["N"]} : (i32) -> ()
return
}
7 changes: 7 additions & 0 deletions mlir/test/Dialect/EmitC/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -282,3 +282,10 @@ func.func @member_access(%arg0: !emitc.opaque<"mystruct">, %arg1: !emitc.opaque<
%2 = "emitc.member_of_ptr" (%arg2) {member = "a"} : (!emitc.ptr<!emitc.opaque<"mystruct">>) -> i32
return
}

func.func @template_args_with_names(%arg0: i32, %arg1: f32) {
emitc.call_opaque "kernel1"(%arg0, %arg1) {template_arg_names = ["N", "P"], template_args = [42 : i32, 56]} : (i32, f32) -> ()
emitc.call_opaque "kernel2"(%arg0, %arg1) {template_arg_names = ["N"], template_args = [42 : i32]} : (i32, f32) -> ()
emitc.call_opaque "kernel3"(%arg0, %arg1) {template_arg_names = [], template_args = [#emitc.opaque<"42">]} : (i32, f32) -> ()
return
}
14 changes: 14 additions & 0 deletions mlir/test/Target/Cpp/template_arg_names.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s -check-prefix=CPP-DEFAULT

// CPP-DEFAULT-LABEL: void basic
func.func @basic(%arg0: i32, %arg1: f32) {
emitc.call_opaque "kernel3"(%arg0, %arg1) : (i32, f32) -> ()
// CPP-DEFAULT: kernel3(
emitc.call_opaque "kernel4"(%arg0, %arg1) {template_arg_names = ["N", "P"], template_args = [42 : i32, 56]} : (i32, f32) -> ()
// CPP-DEFAULT: kernel4</*N=*/42, /*P=*/56>(
emitc.call_opaque "kernel4"(%arg0, %arg1) {template_arg_names = ["N"], template_args = [#emitc.opaque<"42">]} : (i32, f32) -> ()
// CPP-DEFAULT: kernel4</*N=*/42>(
return
}


0 comments on commit e25d207

Please sign in to comment.