diff --git a/xla/python/ifrt/ir/ifrt_ops.cc b/xla/python/ifrt/ir/ifrt_ops.cc index 080f0faf76e72..cc37e5010e084 100644 --- a/xla/python/ifrt/ir/ifrt_ops.cc +++ b/xla/python/ifrt/ir/ifrt_ops.cc @@ -248,12 +248,24 @@ mlir::LogicalResult VerifyIoAlias(mlir::Operation* op, IoAlias io_alias, return mlir::success(); } -mlir::LogicalResult VerifyIoAliases(mlir::Operation* op, - mlir::ArrayAttr io_aliases, - llvm::ArrayRef inputs, - llvm::ArrayRef outputs) { - llvm::SmallSet aliased_inputs; +mlir::LogicalResult VerifyIoAliasesAndDonations( + mlir::Operation* op, mlir::ArrayAttr io_aliases, + llvm::ArrayRef donated_input_indices, + llvm::ArrayRef inputs, + llvm::ArrayRef outputs) { + llvm::SmallSet aliased_or_donated_inputs; llvm::SmallSet aliased_outputs; + for (const int32_t donated_input_index : donated_input_indices) { + if (donated_input_index < 0 || donated_input_index >= inputs.size()) { + return op->emitOpError() + << "can't donate input #" << donated_input_index + << " as only having " << inputs.size() << " inputs"; + } + if (!aliased_or_donated_inputs.insert(donated_input_index).second) { + return op->emitOpError() << "can't donate input #" << donated_input_index + << " more than once"; + } + } for (const auto& raw_io_alias : io_aliases.getAsRange()) { llvm::ArrayRef io_alias_as_array = raw_io_alias.asArrayRef(); @@ -263,9 +275,9 @@ mlir::LogicalResult VerifyIoAliases(mlir::Operation* op, inputs, outputs))) { return mlir::failure(); } - if (!aliased_inputs.insert(aliased_input).second) { - return op->emitOpError() - << "can't alias input #" << aliased_input << " more than once"; + if (!aliased_or_donated_inputs.insert(aliased_input).second) { + return op->emitOpError() << "can't alias or donate input #" + << aliased_input << " more than once"; } if (!aliased_outputs.insert(aliased_output).second) { return op->emitOpError() @@ -618,8 +630,9 @@ mlir::LogicalResult CallOp::verify() { if (mlir::failed(VerifyDevicePlacement(*this, getDevices(), input_arrays, 
output_arrays)) || - mlir::failed(VerifyIoAliases(*this, getIoAliases(), input_arrays, - output_arrays))) { + mlir::failed(VerifyIoAliasesAndDonations(*this, getIoAliases(), + getDonatedInputIndices(), + input_arrays, output_arrays))) { return mlir::failure(); } return mlir::success(); @@ -680,7 +693,9 @@ mlir::LogicalResult CallLoadedExecutableOp::verify() { output_arrays.push_back(mlir::cast(output.getType())); } - return VerifyIoAliases(*this, getIoAliases(), input_arrays, output_arrays); + return VerifyIoAliasesAndDonations(*this, getIoAliases(), + getDonatedInputIndices(), input_arrays, + output_arrays); } mlir::LogicalResult LoadedExecutableOp::verify() { diff --git a/xla/python/ifrt/ir/ifrt_ops.td b/xla/python/ifrt/ir/ifrt_ops.td index 937cdf96ca6e3..1f4bde2c71015 100644 --- a/xla/python/ifrt/ir/ifrt_ops.td +++ b/xla/python/ifrt/ir/ifrt_ops.td @@ -182,9 +182,11 @@ def Ifrt_CallOp : Ifrt_Op<"Call", a subset of these devices. `io_aliases` represents pairs of inputs and outputs, where the input buffer - may be donated and used as the output buffer. The aliased pair must have the - same Ifrt_ArrayType. It's up to IFRT implementations whether to respect this - hint or not. + may be aliased and used as the output buffer. The aliased pair must have the + same byte size. It's up to IFRT implementations whether to respect this + hint or not. Alternatively, if the index of an input is in + `donated_input_indices` then the input buffer might be donated to the + callee if an output with the same byte size is found. 
}]; let arguments = (ins @@ -192,7 +194,8 @@ def Ifrt_CallOp : Ifrt_Op<"Call", Variadic:$control_inputs, SymbolRefAttr:$callee, Ifrt_DevicesAttr:$devices, - DefaultValuedAttr:$io_aliases); + DefaultValuedAttr:$io_aliases, + DefaultValuedAttr:$donated_input_indices); let results = (outs Variadic:$outputs, Ifrt_ControlType:$control_output); @@ -220,16 +223,19 @@ def Ifrt_CallLoadedExecutableOp : Ifrt_Op<"CallLoadedExecutable", be placed on a subset of these devices. `io_aliases` represents pairs of inputs and outputs, where the input buffer - may be donated and used as the output buffer. The aliased pair must have the - same Ifrt_ArrayType. It's up to IFRT implementations whether to respect this - hint or not. + may be aliased and used as the output buffer. The aliased pair must have the + same byte size. It's up to IFRT implementations whether to respect this + hint or not. Alternatively, if the index of an input is in + `donated_input_indices` then the input buffer might be donated to the + callee if an output with the same byte size is found. 
}]; let arguments = (ins Variadic:$inputs, Variadic:$control_inputs, SymbolRefAttr:$callee, - DefaultValuedAttr:$io_aliases); + DefaultValuedAttr:$io_aliases, + DefaultValuedAttr:$donated_input_indices); let results = (outs Variadic:$outputs, Ifrt_ControlType:$control_output); diff --git a/xla/python/ifrt/ir/tests/ifrt_populate_atom_program_metadata.mlir b/xla/python/ifrt/ir/tests/ifrt_populate_atom_program_metadata.mlir index 783ebb26bfdca..3e3728d864ed3 100644 --- a/xla/python/ifrt/ir/tests/ifrt_populate_atom_program_metadata.mlir +++ b/xla/python/ifrt/ir/tests/ifrt_populate_atom_program_metadata.mlir @@ -153,16 +153,17 @@ module @call_twice_with_different_sharding { !array = !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, [0,1]> -// CHECK-LABEL: @populate_io_alias -module @populate_io_alias { - func.func @main(%arg0: !array) attributes {ifrt.function} { - // CHECK: ifrt.Call @[[CALLEE_0:.+]]::@main(%arg0) - %0, %ctrl_0 = ifrt.Call @callee::@main(%arg0) on devices [0,1] - {io_aliases=[array]} : (!array) -> !array +// CHECK-LABEL: @populate_io_alias_and_donation +module @populate_io_alias_and_donation { + func.func @main(%arg0: !array, %arg1: !array) attributes {ifrt.function} { + // CHECK: ifrt.Call @[[CALLEE_0:.+]]::@main(%arg0, %arg1) + %0, %ctrl_0 = ifrt.Call @callee::@main(%arg0, %arg1) on devices [0,1] + {io_aliases=[array], donated_input_indices=array} + : (!array, !array) -> !array // Verify that the module is cloned if io_aliases differ. 
- // CHECK: ifrt.Call @[[CALLEE_1:.+]]::@main(%arg0) - %1, %ctrl_1 = ifrt.Call @callee::@main(%arg0) on devices [0,1] - : (!array) -> !array + // CHECK: ifrt.Call @[[CALLEE_1:.+]]::@main(%arg0, %arg1) + %1, %ctrl_1 = ifrt.Call @callee::@main(%arg0, %arg1) on devices [0,1] + : (!array, !array) -> !array return } @@ -188,8 +189,15 @@ module @populate_io_alias { // CHECK-DAG: ifrt.devices = #ifrt // CHECK-DAG: tf.aliasing_output = 0 : i32 // CHECK-SAME: } + // CHECK: %arg1: tensor<2x2xi32> + // CHECK-SAME: { + // CHECK-DAG: ifrt.sharding = #ifrt.sharding_param<2x1 to [0] on 2> + // CHECK-DAG: ifrt.devices = #ifrt + // CHECK-DAG: jax.buffer_donor = true + // CHECK-SAME: } module @callee attributes {sym_visibility = "private"} { - func.func private @main(%arg0: tensor<2x2xi32>) -> tensor<2x2xi32> { + func.func private @main(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi32>) + -> tensor<2x2xi32> { return %arg0: tensor<2x2xi32> } } diff --git a/xla/python/ifrt/ir/tests/ifrt_verify_donation.mlir b/xla/python/ifrt/ir/tests/ifrt_verify_donation.mlir index 8c70318c03598..92bed2748c218 100644 --- a/xla/python/ifrt/ir/tests/ifrt_verify_donation.mlir +++ b/xla/python/ifrt/ir/tests/ifrt_verify_donation.mlir @@ -41,7 +41,7 @@ module @donate_to_reshard_duplicated_arg { // ----- !array = !ifrt.array, #ifrt.sharding_param<2 to [0] on 2>, [0, 1]> -module @donate_to_two_calls_error { +module @alias_to_two_calls_error { func.func @main(%arg0: !array {ifrt.donated}) -> (!array, !array) attributes {ifrt.function} { %0, %ctrl_0 = ifrt.Call @identity(%arg0) on devices [0,1] @@ -59,13 +59,49 @@ module @donate_to_two_calls_error { // ----- +!array = !ifrt.array, #ifrt.sharding_param<2 to [0] on 2>, [0, 1]> +module @donate_to_two_calls_error { + func.func @main(%arg0: !array {ifrt.donated}) -> (!array, !array) + attributes {ifrt.function} { + %0, %ctrl_0 = ifrt.Call @identity(%arg0) on devices [0,1] + {donated_input_indices=array} : (!array) -> !array + // expected-error @+1 {{'ifrt.Call' op 
input #0 of @identity was already donated}} + %1, %ctrl_1 = ifrt.Call @identity(%arg0) on devices [0,1] + {donated_input_indices=array} : (!array) -> !array + return %0, %1 : !array, !array + } + + func.func private @identity(%arg0: tensor<2xi32>) -> tensor<2xi32> { + return %arg0 : tensor<2xi32> + } +} + +// ----- + +!array = !ifrt.array, #ifrt.sharding_param<2 to [0] on 2>, [0, 1]> +module @arg_donated_to_call_not_donated_to_program { + func.func @main(%arg0: !array) -> (!array) + attributes {ifrt.function} { + // expected-error @+1 {{'ifrt.Call' op input #0 has not been donated to the program.}} + %0, %ctrl_0 = ifrt.Call @identity(%arg0) on devices [0,1] + {donated_input_indices=array} : (!array) -> !array + return %0 : !array + } + + func.func private @identity(%arg0: tensor<2xi32>) -> tensor<2xi32> { + return %arg0 : tensor<2xi32> + } +} + +// ----- + !array0 = !ifrt.array, #ifrt.sharding_param<2 to [0] on 2>, [0, 1]> !array1 = !ifrt.array, #ifrt.sharding_param<2 to [0] on 2>, [2, 3]> module @program_arg_not_donated_error { func.func @main(%arg0: !array0) -> (!array1) attributes {ifrt.function} { - // expected-error @+1 {{'ifrt.Reshard' op input has not been donated to the program.}} + // expected-error @+1 {{'ifrt.Reshard' op input #0 has not been donated to the program.}} %0, %ctrl_0 = ifrt.Reshard(%arg0) {donated=true} : (!array0) -> !array1 return %0 : !array1 } @@ -167,7 +203,7 @@ module @donate_to_two_copy_arrays_error { module @program_arg_not_donated_to_remap_error { func.func @main(%arg0: !array {ifrt.donated}, %arg1: !array) -> (!array) attributes {ifrt.function} { - // expected-error @+1 {{'ifrt.RemapArrays' op input has not been donated to the program.}} + // expected-error @+1 {{'ifrt.RemapArrays' op input #1 has not been donated to the program.}} %0 = ifrt.RemapArrays(%arg0, %arg1) mappings=[#ifrt.array_mapping<0, 0, [#ifrt.mapping<[0:1:1] to [0:1:1]>]>, #ifrt.array_mapping<1, 0, [#ifrt.mapping<[0:1:1] to [1:2:1]>]>] diff --git 
a/xla/python/ifrt/ir/tests/verify_call.mlir b/xla/python/ifrt/ir/tests/verify_call.mlir index e512b260600e7..202724e44496a 100644 --- a/xla/python/ifrt/ir/tests/verify_call.mlir +++ b/xla/python/ifrt/ir/tests/verify_call.mlir @@ -293,7 +293,7 @@ func.func @io_aliases_should_only_alias_input_once( %arg0: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0,1]>) attributes {ifrt.function} { - // expected-error@+1 {{'ifrt.Call' op can't alias input #0 more than once}} + // expected-error@+1 {{'ifrt.Call' op can't alias or donate input #0 more than once}} %0, %1, %ctrl_0 = ifrt.Call @callee(%arg0) on devices [0,1] {io_aliases=[array, array]} : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, @@ -429,4 +429,57 @@ func.func @call_local_view_should_have_valid_shape( func.func @callee(%arg0: tensor<4x4xi32>) -> tensor<4x4xi32> { return %arg0 : tensor<4x4xi32> +} + +// ----- + +!array = !ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 2>, [0,1]> +func.func @donate_an_arg_and_alias_another(%arg0: !array, %arg1: !array) + attributes {ifrt.function} { + %0, %ctrl_0 = ifrt.Call @callee(%arg0, %arg1) on devices [0,1] + {donated_input_indices=array, io_aliases=[array]} + : (!array, !array) -> !array + return +} + +func.func @callee(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi32>) + -> tensor<2x2xi32> { + return %arg0 : tensor<2x2xi32> +} + +// ----- + +!array = !ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 2>, [0,1]> +func.func @should_only_donate_once(%arg0: !array, %arg1: !array) + attributes {ifrt.function} { + // expected-error@+1 {{'ifrt.Call' op can't donate input #0 more than once}} + %0, %ctrl_0 = ifrt.Call @callee(%arg0, %arg1) on devices [0,1] + {donated_input_indices=array} + : (!array, !array) -> !array + return +} + +func.func @callee(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi32>) + -> tensor<2x2xi32> { + return %arg0 : tensor<2x2xi32> +} + +// ----- + +!array = !ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 2>, [0,1]> +func.func 
@should_not_both_donate_and_alias_the_same_arg( + %arg0: !array, %arg1: !array) attributes {ifrt.function} { + // expected-error@+1 {{'ifrt.Call' op can't alias or donate input #0 more than once}} + %0, %ctrl_0 = ifrt.Call @callee(%arg0, %arg1) on devices [0,1] + {donated_input_indices=array, io_aliases=[array]} + : (!array, !array) -> !array + return +} + +func.func @callee(%arg0: tensor<2x2xi32>, %arg1: tensor<2x2xi32>) + -> tensor<2x2xi32> { + return %arg0 : tensor<2x2xi32> } \ No newline at end of file diff --git a/xla/python/ifrt/ir/tests/verify_call_loaded_executable.mlir b/xla/python/ifrt/ir/tests/verify_call_loaded_executable.mlir index 14485f4c86a4e..e41add06877c6 100644 --- a/xla/python/ifrt/ir/tests/verify_call_loaded_executable.mlir +++ b/xla/python/ifrt/ir/tests/verify_call_loaded_executable.mlir @@ -145,7 +145,7 @@ func.func @io_aliases_should_only_alias_input_once( %arg0: !ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, [0,1]>) attributes {ifrt.function} { - // expected-error@+1 {{'ifrt.CallLoadedExecutable' op can't alias input #0 more than once}} + // expected-error@+1 {{'ifrt.CallLoadedExecutable' op can't alias or donate input #0 more than once}} %0, %1, %ctrl_0 = ifrt.CallLoadedExecutable @callee(%arg0) {io_aliases=[array, array]} : (!ifrt.array, #ifrt.sharding_param<1x1 to [0] on 2>, @@ -230,3 +230,47 @@ ifrt.LoadedExecutable @callee on devices [0,1] [0,1]>) -> !ifrt.array, #ifrt.sharding_param<2x1 to [0] on 2>, [0,1]> + + +// ----- + +!array = !ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 2>, [0,1]> +func.func @donate_one_arg_and_alias_another_arg(%arg0: !array, %arg1: !array) + attributes {ifrt.function} { + %0, %ctrl_0 = ifrt.CallLoadedExecutable @callee(%arg0, %arg1) + {donated_input_indices=array, io_aliases=[array]} + : (!array, !array) -> !array + return +} + +ifrt.LoadedExecutable @callee on devices [0,1] : (!array, !array) -> !array + +// ----- + +!array = !ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 2>, [0,1]> +func.func 
@should_only_donate_once(%arg0: !array, %arg1: !array) + attributes {ifrt.function} { + // expected-error@+1 {{'ifrt.CallLoadedExecutable' op can't donate input #0 more than once}} + %0, %ctrl_0 = ifrt.CallLoadedExecutable @callee(%arg0, %arg1) + {donated_input_indices=array} : (!array, !array) -> !array + return +} + +ifrt.LoadedExecutable @callee on devices [0,1] : (!array, !array) -> !array + +// ----- + +!array = !ifrt.array, + #ifrt.sharding_param<1x1 to [0] on 2>, [0,1]> +func.func @should_not_both_donate_and_alias_the_same_arg( + %arg0: !array, %arg1: !array) attributes {ifrt.function} { + // expected-error@+1 {{'ifrt.CallLoadedExecutable' op can't alias or donate input #0 more than once}} + %0, %ctrl_0 = ifrt.CallLoadedExecutable @callee(%arg0, %arg1) + {donated_input_indices=array, io_aliases=[array]} + : (!array, !array) -> !array + return +} + +ifrt.LoadedExecutable @callee on devices [0,1] : (!array, !array) -> !array diff --git a/xla/python/ifrt/ir/transforms/ifrt_populate_atom_program_metadata_pass.cc b/xla/python/ifrt/ir/transforms/ifrt_populate_atom_program_metadata_pass.cc index 5b6e8268ac090..833658f923c92 100644 --- a/xla/python/ifrt/ir/transforms/ifrt_populate_atom_program_metadata_pass.cc +++ b/xla/python/ifrt/ir/transforms/ifrt_populate_atom_program_metadata_pass.cc @@ -159,6 +159,9 @@ mlir::LogicalResult PopulateMetadata(xla::ifrt::CallOp call_op, callee_op.setArgAttr(io_alias_as_array[0], "tf.aliasing_output", builder.getI32IntegerAttr(io_alias_as_array[1])); } + for (const auto idx : call_op.getDonatedInputIndices()) { + callee_op.setArgAttr(idx, "jax.buffer_donor", builder.getBoolAttr(true)); + } return mlir::success(); } diff --git a/xla/python/ifrt/ir/transforms/ifrt_verify_donation_pass.cc b/xla/python/ifrt/ir/transforms/ifrt_verify_donation_pass.cc index 7e3492147e166..7a4fcfdf16be9 100644 --- a/xla/python/ifrt/ir/transforms/ifrt_verify_donation_pass.cc +++ b/xla/python/ifrt/ir/transforms/ifrt_verify_donation_pass.cc @@ -38,17 +38,100 
@@ namespace { #include "xla/python/ifrt/ir/transforms/passes.h.inc" // Verifies that if the value is an input to the IR, then it has been donated. -mlir::LogicalResult VerifyIfInputAndDonated(mlir::Operation* op, +mlir::LogicalResult VerifyIfInputAndDonated(mlir::Operation* op, int idx, mlir::Value arg) { auto block_arg = mlir::dyn_cast(arg); mlir::func::FuncOp func_op = block_arg ? mlir::dyn_cast( block_arg.getOwner()->getParentOp()) : nullptr; - if (func_op && - func_op.getArgAttr(block_arg.getArgNumber(), - xla::ifrt::kIfrtDonatedArgAttrName) == nullptr) { - return op->emitOpError() << "input has not been donated to the program."; + if (func_op && func_op.getArgAttr(block_arg.getArgNumber(), + kIfrtDonatedArgAttrName) == nullptr) { + return op->emitOpError() + << "input #" << idx << " has not been donated to the program."; + } + return mlir::success(); +} + +template +mlir::LogicalResult verifyCallOpAliasesAndDonations( + T op, llvm::DenseMap& donated_value_to_op) { + llvm::DenseSet donated_input_idxs; + // Verify if a donated input is an argument of the main func, then it has + // also been donated by the user. 
+ for (const auto idx : op.getDonatedInputIndices()) { + donated_input_idxs.insert(idx); + auto donated_value = op.getInputs()[idx]; + auto donated_it = donated_value_to_op.try_emplace(donated_value, op); + if (!donated_it.second) { + op.emitOpError() << "input #" << idx << " of " << op.getCalleeAttr() + << " was already donated or aliased to the op at " + << donated_it.first->second->getLoc(); + return mlir::failure(); + } + if (mlir::failed(VerifyIfInputAndDonated(op, idx, donated_value))) { + return mlir::failure(); + } + } + + for (const auto& io_alias : + op.getIoAliases().template getAsRange()) { + mlir::ArrayRef io_alias_as_array = io_alias.asArrayRef(); + donated_input_idxs.insert(io_alias_as_array[0]); + auto aliased_value = op.getInputs()[io_alias_as_array[0]]; + auto donated_it = donated_value_to_op.try_emplace(aliased_value, op); + if (!donated_it.second) { + op.emitOpError() << "input #" << io_alias_as_array[0] << " of " + << op.getCalleeAttr() + << " was already donated or aliased to the op at " + << donated_it.first->second->getLoc(); + return mlir::failure(); + } + if (mlir::failed( + VerifyIfInputAndDonated(op, io_alias_as_array[0], aliased_value))) { + return mlir::failure(); + } + } + + // Verify non-donated inputs after donated inputs have been + // added to also catch instances such as + // `ifrt.Call(%arg0 {ifrt.donated}, %arg0})`. + for (const auto [idx, input] : llvm::enumerate(op.getInputs())) { + if (!donated_input_idxs.contains(idx)) { + auto donated_it = donated_value_to_op.find(input); + if (donated_it != donated_value_to_op.end()) { + op.emitOpError() << "input #" << idx << " of " << op.getCalleeAttr() + << " was already donated to the op at " + << donated_it->second->getLoc(); + return mlir::failure(); + } + } + } + return mlir::success(); +} + +template +mlir::LogicalResult verifyCopyRemapAndReshardOpsDonation( + T op, llvm::DenseMap& donated_value_to_op) { + // Verify that no inputs have already been donated. 
+ for (const auto [idx, input] : llvm::enumerate(op.getInputs())) { + auto donated_it = donated_value_to_op.find(input); + if (donated_it != donated_value_to_op.end()) { + op.emitOpError() << "input #" << idx << " of op at " << op.getLoc() + << " was already donated to the op at " + << donated_it->second->getLoc(); + return mlir::failure(); + } + } + if (op.getDonated()) { + // Add the donated inputs to the map and verify that all the + // donated inputs are also donated to the main func. + for (const auto [idx, input] : llvm::enumerate(op.getInputs())) { + donated_value_to_op.try_emplace(input, op); + if (mlir::failed(VerifyIfInputAndDonated(op, idx, input))) { + return mlir::failure(); + } + } } return mlir::success(); } @@ -74,72 +157,12 @@ void IfrtVerifyDonationPass::runOnOperation() { -> mlir::WalkResult { auto result = llvm::TypeSwitch(op) - .Case( - [&](auto& op) { - llvm::DenseSet donated_input_idxs; - for (const auto& io_alias : - op.getIoAliases() - .template getAsRange()) { - mlir::ArrayRef io_alias_as_array = - io_alias.asArrayRef(); - donated_input_idxs.insert(io_alias_as_array[0]); - auto donated_value = op.getInputs()[io_alias_as_array[0]]; - auto donated_it = - donated_value_to_op.try_emplace(donated_value, op); - if (!donated_it.second) { - op.emitOpError() << "input #" << io_alias_as_array[0] - << " of " << op.getCalleeAttr() - << " was already donated to the op at " - << donated_it.first->second->getLoc(); - return mlir::failure(); - } - if (mlir::failed( - VerifyIfInputAndDonated(op, donated_value))) { - return mlir::failure(); - } - } - // Verify non-donated inputs after donated inputs have been - // added to also catch instances such as - // `ifrt.Call(%arg0 {ifrt.donated}, %arg0})`. 
- for (const auto [idx, input] : - llvm::enumerate(op.getInputs())) { - if (!donated_input_idxs.contains(idx)) { - auto donated_it = donated_value_to_op.find(input); - if (donated_it != donated_value_to_op.end()) { - op.emitOpError() - << "input #" << idx << " of " << op.getCalleeAttr() - << " was already donated to the op at " - << donated_it->second->getLoc(); - return mlir::failure(); - } - } - } - return mlir::success(); - }) - .Case([&](auto& op) { - // Verify that no inputs have already been donated. - for (const auto [idx, input] : llvm::enumerate(op.getInputs())) { - auto donated_it = donated_value_to_op.find(input); - if (donated_it != donated_value_to_op.end()) { - op.emitOpError() - << "input #" << idx << " of op at " << op.getLoc() - << " was already donated to the op at " - << donated_it->second->getLoc(); - return mlir::failure(); - } - } - if (op.getDonated()) { - // Add the donated inputs to the map and verify that all the - // donated inputs are also donated to the main func. - for (const auto input : op.getInputs()) { - donated_value_to_op.try_emplace(input, op); - if (mlir::failed(VerifyIfInputAndDonated(op, input))) { - return mlir::failure(); - } - } - } - return mlir::success(); + .Case([&](auto& op) { + return verifyCallOpAliasesAndDonations(op, donated_value_to_op); + }) + .Case([&](auto& op) { + return verifyCopyRemapAndReshardOpsDonation(op, + donated_value_to_op); }) .Case([&](mlir::func::ReturnOp return_op) { for (const auto& [idx, result] : diff --git a/xla/python/ifrt/ir/transforms/passes.td b/xla/python/ifrt/ir/transforms/passes.td index a75fc059cb8e6..6acb62e678de2 100644 --- a/xla/python/ifrt/ir/transforms/passes.td +++ b/xla/python/ifrt/ir/transforms/passes.td @@ -214,6 +214,9 @@ For every CallOp, this pass main FuncOp 3. attaches `tf.aliasing_output` attr to the callee main FuncOp's inputs according to `io_aliases` + 4. 
attaches `jax.buffer_donor` attr to the callee main FuncOp's inputs + according to `donated_input_indices` + For CallOps with the same callee, a different clone will be created for each CallOp, even if the populated metadata are the same. User may want to run `ifrt-duplicated-callee-elimination` pass to dedup the clones.