Add sequence reshape operator #7662
Conversation
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
    sequence_reshape,
    ops::SequenceReshapeKernel<paddle::platform::CUDADeviceContext, float>);
You also need to register the double type for sequence_reshape.
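For illustration, a minimal sketch of registering both float and double CUDA kernels (the PR's final code may differ):

// Register both float and double CUDA kernels for sequence_reshape.
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
    sequence_reshape,
    ops::SequenceReshapeKernel<paddle::platform::CUDADeviceContext, float>,
    ops::SequenceReshapeKernel<paddle::platform::CUDADeviceContext, double>);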
Done.
then out is a LoDTensor:
out.lod = [[0, 1, 3]]
out.data = [[0.1, 0.2, 0.3, 0.4],
            [0.5, 0.6, 0.7, 0.8],
            [0.9, 1.0, 1.1, 1.2]]
The data can be written as integers so that the example is easier to read.
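For illustration only, the same out example rewritten with integer values:

out.lod = [[0, 1, 3]]
out.data = [[1, 2, 3, 4],
            [5, 6, 7, 8],
            [9, 10, 11, 12]]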
Done.
"to 0 after reshaped.", | ||
i + 1); | ||
out_lod[0].push_back(out_lod[0].back() + offset); | ||
} |
I think that lines 50~64 should be moved into InferShape; this code belongs to input data validity checking.
I think it's ok to do this in the kernel.
auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(x_dims.size(), 2U, "Rank of Input(X) should be 2.");
int dimension = ctx->Attrs().Get<int>("new_dim");
ctx->SetOutputDim("Out", {x_dims[0], static_cast<int64_t>(dimension)});
The output dim may not be {x_dims[0], dimension}, and the output dim can be computed in InferShape.
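For illustration, a minimal sketch (variable names assumed, not the PR's final code) of computing the output dim in InferShape; the total number of elements is preserved, so the row count scales with the width change:

// Sketch: rows * width is invariant under the reshape.
auto x_dims = ctx->GetInputDim("X");
int new_dim = ctx->Attrs().Get<int>("new_dim");
int64_t out_rows = x_dims[0] * x_dims[1] / new_dim;
ctx->SetOutputDim("Out", {out_rows, static_cast<int64_t>(new_dim)});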
auto& out_lod = *out->mutable_lod();
out_lod.resize(1);
out_lod[0].clear();
out_lod[0].push_back(0);
What if out_width equals in_dims[1]?
Just do the copy.
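A minimal sketch of that fast path, assuming the surrounding kernel's names (out_width, in_dims, in, out, context):

// Sketch: when the new width equals the input width the reshape is an
// identity, so a plain copy with the unchanged input LoD is enough.
if (out_width == in_dims[1]) {
  framework::Copy(*in, context.GetPlace(), out);
  out->set_lod(in->lod());
  return;
}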
p_in_data + in_offset, bytes, dev_ctx.stream());
#endif
}
}
From the description of the example, you only need to copy the input to the output and reset out_lod and the output dims; it does not need to be this complex.
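A hedged sketch of that simplification (helper names such as in_width and new_dim are assumed): one copy plus recomputing the LoD and dims, instead of a per-sequence copy loop.

// Sketch: the flat data layout is unchanged, only the row width differs.
int64_t in_width = in->dims()[1];
framework::Copy(*in, context.GetPlace(), out);
out->Resize({in->numel() / new_dim, static_cast<int64_t>(new_dim)});

// Rebuild the LoD: every offset scales by in_width / new_dim
// (divisibility is assumed to have been checked already).
auto& out_lod = *out->mutable_lod();
const auto& in_lod = in->lod();
out_lod.resize(1);
out_lod[0].resize(in_lod[0].size());
for (size_t i = 0; i < in_lod[0].size(); ++i) {
  out_lod[0][i] = in_lod[0][i] * in_width / new_dim;
}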
Thanks
}

out->mutable_data<T>(context.GetPlace());
framework::Copy(*in, context.GetPlace(), out);
Line 65 can be moved to line 40, and out->mutable_data<T>(context.GetPlace()); can be removed.
It seems Copy will invoke mutable_data on the destination tensor, so L64 is not necessary.
} else {
auto& out_lod = *out->mutable_lod();
out_lod.resize(1);
out_lod[0].clear();
push_back effectively increases the container size by one, which causes an automatic reallocation of the allocated storage if, and only if, the new vector size surpasses the current vector capacity. You can replace out_lod[0].clear(); with out_lod[0].resize(seq_num);.
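A small sketch of the suggested pattern (assuming seq_num and the per-sequence offset are already available): size the level once up front instead of growing it with repeated push_back.

auto& out_lod = *out->mutable_lod();
out_lod.resize(1);
// One allocation instead of potentially many reallocations.
out_lod[0].resize(seq_num + 1);
out_lod[0][0] = 0;
for (size_t i = 0; i < seq_num; ++i) {
  // offset: the reshaped length of sequence i (computed as in the kernel).
  out_lod[0][i + 1] = out_lod[0][i] + offset;
}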
op_desc_ptr->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op_desc_ptr->SetAttrMap(Attrs());
return std::unique_ptr<framework::OpDesc>(op_desc_ptr);
}
I don't think you need to override Apply; you can use the default xxxGradOpMaker.
You can refer to this, and register the op with REGISTER_OP.
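For reference, a sketch of that registration (class names are assumed; the macro wires in the default grad op maker, so Apply does not need to be overridden):

// REGISTER_OP(op_type, op_class, op_maker_class, grad_op_type, grad_op_class)
namespace ops = paddle::operators;
REGISTER_OP(sequence_reshape, ops::SequenceReshapeOp,
            ops::SequenceReshapeOpMaker, sequence_reshape_grad,
            ops::SequenceReshapeGradOp);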
I think adding a GradOpMaker explicitly is not harmful.
Yes, I agree with you, but it seems unnecessary, and sequence_reshape_op should be consistent with other ops.
I think Apply should only be overridden in complex ops such as while_op and recurrent_op, because the default GradOpMaker does not meet those ops' needs.
The default GradOpMaker will make the prototxt contain many unnecessary variables.
I see, it is helpful for memory optimization.
LGTM+
Resolves #6678