|
|
|
@ -25,13 +25,13 @@ class SequencePadOp : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"),
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
|
|
|
|
|
"Input(X) of SequencePadOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("PadValue"),
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("PadValue"), true,
|
|
|
|
|
"Input(PadValue) of SequencePadOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
|
|
|
|
|
"Output(Out) of SequencePadOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Length"),
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasOutput("Length"), true,
|
|
|
|
|
"Output(Length) of SequencePadOp should not be null.");
|
|
|
|
|
|
|
|
|
|
auto x_dims = ctx->GetInputDim("X");
|
|
|
|
@ -39,8 +39,9 @@ class SequencePadOp : public framework::OperatorWithKernel {
|
|
|
|
|
"The rank of Input(X) can't be less than 2.");
|
|
|
|
|
auto time_step_dims = framework::slice_ddim(x_dims, 1, x_dims.size());
|
|
|
|
|
auto pad_value_dims = ctx->GetInputDim("PadValue");
|
|
|
|
|
PADDLE_ENFORCE(pad_value_dims == framework::make_ddim({1}) ||
|
|
|
|
|
PADDLE_ENFORCE_EQ(pad_value_dims == framework::make_ddim({1}) ||
|
|
|
|
|
pad_value_dims == time_step_dims,
|
|
|
|
|
true,
|
|
|
|
|
"The Input(PadValue) must be a scalar or a tensor whose "
|
|
|
|
|
"shape equals to time steps in sequences");
|
|
|
|
|
|
|
|
|
@ -52,7 +53,8 @@ class SequencePadOp : public framework::OperatorWithKernel {
|
|
|
|
|
framework::Variable* x_var =
|
|
|
|
|
boost::get<framework::Variable*>(ctx->GetInputVarPtrs("X")[0]);
|
|
|
|
|
const auto& x_lod = x_var->Get<LoDTensor>().lod();
|
|
|
|
|
PADDLE_ENFORCE(!x_lod.empty(), "The Input(X) must hold lod info.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_lod.empty(), false,
|
|
|
|
|
"The Input(X) must hold lod info.");
|
|
|
|
|
const auto& x_lod_0 = x_lod[0];
|
|
|
|
|
PADDLE_ENFORCE_GE(x_lod_0.size(), 2,
|
|
|
|
|
"The Input(X)'s lod info is corrupted.");
|
|
|
|
@ -80,7 +82,7 @@ class SequencePadOp : public framework::OperatorWithKernel {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<int> out_dims_vec{out_dim_0, padded_length};
|
|
|
|
|
std::vector<int> len_dims_vec{out_dim_0, 1};
|
|
|
|
|
std::vector<int> len_dims_vec{out_dim_0};
|
|
|
|
|
auto time_step_dims_vec = framework::vectorize<int>(time_step_dims);
|
|
|
|
|
out_dims_vec.insert(out_dims_vec.end(), time_step_dims_vec.begin(),
|
|
|
|
|
time_step_dims_vec.end());
|
|
|
|
@ -143,7 +145,7 @@ class SequencePadOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
then we get LoDTensor:
|
|
|
|
|
Out.data = [[a, b, 0, 0],
|
|
|
|
|
[c, d, e, 0]]
|
|
|
|
|
Length.data = [[2], [3]]
|
|
|
|
|
Length.data = [2, 3]
|
|
|
|
|
|
|
|
|
|
Case 2:
|
|
|
|
|
|
|
|
|
@ -157,7 +159,7 @@ class SequencePadOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
then we get LoDTensor:
|
|
|
|
|
Out.data = [[[a1, a2], [b1, b2], [0, 0]],
|
|
|
|
|
[[c1, c2], [d1, d2], [e1, e2]]]
|
|
|
|
|
Length.data = [[2], [3]]
|
|
|
|
|
Length.data = [2, 3]
|
|
|
|
|
|
|
|
|
|
Case 3:
|
|
|
|
|
|
|
|
|
@ -171,7 +173,7 @@ class SequencePadOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
then we get LoDTensor:
|
|
|
|
|
Out.data = [[[a1, a2], [b1, b2], [p1, p2]],
|
|
|
|
|
[[c1, c2], [d1, d2], [e1, e2]]]
|
|
|
|
|
Length.data = [[2], [3]]
|
|
|
|
|
Length.data = [2, 3]
|
|
|
|
|
|
|
|
|
|
)DOC");
|
|
|
|
|
}
|
|
|
|
@ -182,9 +184,10 @@ class SequencePadGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"),
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
|
|
|
|
|
"Input(X) of SequencePadGradOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx->HasInput(framework::GradVarName("Out")), true,
|
|
|
|
|
"Input(Out@GRAD) of SequencePadGradOp should not be null.");
|
|
|
|
|
|
|
|
|
|
if (ctx->HasOutput(framework::GradVarName("X"))) {
|
|
|
|
|