Skip to content

Commit

Permalink
fix (#60625)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangbo9674 authored Jan 9, 2024
1 parent 7a363e7 commit fbb5801
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ void ShareVarData(const Variable* src_var, Variable* dst_var) {
} else if (src_var->IsType<phi::TensorArray>()) {
auto src_tensor_array = src_var->Get<phi::TensorArray>();
auto* dst_tensor_array = dst_var->GetMutable<phi::TensorArray>();
if (src_tensor_array.numel() == 0) return;
if (src_tensor_array.size() == 0) return;
dst_tensor_array->clear();
for (auto src_tensor : src_tensor_array) {
phi::DenseTensor* tmp_dst_tensor = new phi::DenseTensor();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,21 @@ WhileInstruction::WhileInstruction(
GetInputIds(op, *parent_exe_info, &inputs);
auto body_outside_inputs =
GetExternalInputs(body_block_, *parent_exe_info, &inputs);
// NOTE(chenxi67): the variable corresponding to container value if a
// <VariableRefArray> Type. It will recursively get the ID of internal
// variables when use GetValueId() method. However, the copy_var pushed into
// the tuple does not have a corresponding ID, and will insert a -1. Here we
// remove the value of -1.
for (auto& item : inputs) {
auto& var_vec = item.second;
for (auto it = var_vec.begin(); it != var_vec.end();) {
if (*it == -1) {
it = var_vec.erase(it);
} else {
++it;
}
}
}
SetInputs(inputs);

std::unordered_map<pir::Value, std::vector<int>> outputs;
Expand Down

0 comments on commit fbb5801

Please sign in to comment.