|
|
|
@ -27,9 +27,7 @@ class SeqExpandOp : public framework::OperatorWithKernel {
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"));
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"));
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
ctx->HasInput("Y"),
|
|
|
|
|
"Input(Y) of SeqExpandOp should not be null while repeat == 0.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Y"));
|
|
|
|
|
framework::DDim out_dim;
|
|
|
|
|
out_dim = ctx->GetInputDim("Y");
|
|
|
|
|
ctx->ShareLoD("Y", "Out");
|
|
|
|
@ -43,14 +41,14 @@ class SeqExpandOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
framework::OpAttrChecker* op_checker)
|
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
|
AddInput("X",
|
|
|
|
|
"(Tensor or LoDTensor) The input('X') of this operator can be a "
|
|
|
|
|
"(Tensor or LoDTensor) The input(X) of this operator can be a "
|
|
|
|
|
"LoDTensor or a base Tensor.");
|
|
|
|
|
AddInput("Y",
|
|
|
|
|
"(LoDTensor)The reference input('Y') of seq_expand op."
|
|
|
|
|
"(LoDTensor)The reference input(Y) of seq_expand op."
|
|
|
|
|
"It must be a LoDTensor with k-level(k>0)."
|
|
|
|
|
"Input(X) will be expanded according to LOD of input(Y)."
|
|
|
|
|
"The element numbers of last level in input('Y') "
|
|
|
|
|
"must be equal to dims[0] of input('X').");
|
|
|
|
|
"The input(X) will be expanded according to LOD of input(Y)."
|
|
|
|
|
"The element numbers of last level in input(Y) "
|
|
|
|
|
"must be equal to dims[0] of input(X).");
|
|
|
|
|
AddOutput("Out",
|
|
|
|
|
"(LodTensor)The output of seq_expand op."
|
|
|
|
|
"The lod of output will be as same as input(Y)'s lod.");
|
|
|
|
@ -133,7 +131,7 @@ class SeqExpandOpGrad : public framework::OperatorWithKernel {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"));
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Out"));
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
|
|
|
|
|
"Input(Out@GRAD) should not be null");
|
|
|
|
|
"The input(Out@GRAD) should not be null");
|
|
|
|
|
auto x_dims = ctx->GetInputDim("X");
|
|
|
|
|
auto x_grad_name = framework::GradVarName("X");
|
|
|
|
|
if (ctx->HasOutput(x_grad_name)) {
|
|
|
|
|