|
|
|
@ -27,18 +27,18 @@ class SequenceExpandAsOp : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"),
|
|
|
|
|
"Input(X) of SequenceExpandAsOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Y"),
|
|
|
|
|
"Input(Y) of SequenceExpandAsOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
|
|
|
|
"Output(Out) of SequenceExpandAsOp should not be null.");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInputs("X"), "Input", "X", "SequenceExpandAs");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInputs("Y"), "Input", "Y", "SequenceExpandAs");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "SequenceExpandAs");
|
|
|
|
|
|
|
|
|
|
auto x_dims = ctx->GetInputDim("X");
|
|
|
|
|
auto out_dims = x_dims;
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_GE(x_dims.size(), 2,
|
|
|
|
|
"Dimension number of Input(X) should be at least 2.");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Dimension number of Input(X) should be at least 2. "
|
|
|
|
|
"But received X's dimensions = %d, X's shape = [%s].",
|
|
|
|
|
x_dims.size(), x_dims));
|
|
|
|
|
|
|
|
|
|
if (ctx->IsRuntime()) {
|
|
|
|
|
framework::Variable* x_var =
|
|
|
|
@ -50,11 +50,17 @@ class SequenceExpandAsOp : public framework::OperatorWithKernel {
|
|
|
|
|
auto& y_lod = y_var->Get<LoDTensor>().lod();
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(y_lod.size(), 1,
|
|
|
|
|
"Level number of Input(Y)'s lod should be 1.");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Level number of Input(Y)'s lod should be 1. But "
|
|
|
|
|
"received Y's lod level = %d.",
|
|
|
|
|
y_lod.size()));
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(static_cast<size_t>(x_dim[0]), y_lod[0].size() - 1,
|
|
|
|
|
"The first dimension of Input(X) should be equal "
|
|
|
|
|
"to the size of Input(Y)'s 0 level lod.");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The first dimension of Input(X) should be one "
|
|
|
|
|
"less than the size of Input(Y)'s 0 level lod. But "
|
|
|
|
|
"received X's shape[0] = %d, Y's lod[0].size = %d.",
|
|
|
|
|
x_dim[0], y_lod[0].size()));
|
|
|
|
|
|
|
|
|
|
int64_t out_first_dim = 0;
|
|
|
|
|
if (y_lod[0].size() <= 1) {
|
|
|
|
@ -138,9 +144,9 @@ class SequenceExpandAsOpGrad : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
|
|
|
|
|
"Input(Out@GRAD) should not be null.");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInputs("X"), "Input", "X", "SequenceExpandAsGrad");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInputs(framework::GradVarName("Out")), "Input",
|
|
|
|
|
"Out@GRAD", "SequenceExpandAsGrad");
|
|
|
|
|
|
|
|
|
|
auto x_dims = ctx->GetInputDim("X");
|
|
|
|
|
auto x_grad_name = framework::GradVarName("X");
|
|
|
|
|