@@ -17,7 +17,7 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
-using framework::Tensor;
+using framework::LoDTensor;
 
 class SequenceExpandOp : public framework::OperatorWithKernel {
  public:
@@ -25,15 +25,67 @@ class SequenceExpandOp : public framework::OperatorWithKernel {
  protected:
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"));
-    PADDLE_ENFORCE(ctx->HasOutput("Out"));
-    PADDLE_ENFORCE(ctx->HasInput("Y"));
-    framework::DDim out_dim;
-    auto y_dim = ctx->GetInputDim("Y");
-    out_dim = ctx->GetInputDim("X");
-    out_dim[0] = y_dim[0];
-    ctx->ShareLoD("Y", "Out");
-    ctx->SetOutputDim("Out", out_dim);
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of SequenceExpandOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Y"),
+                   "Input(Y) of SequenceExpandOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "Output(Out) of SequenceExpandOp should not be null.");
+
+    auto x_dims = ctx->GetInputDim("X");
+    PADDLE_ENFORCE_EQ(x_dims.size(), 2U,
+                      "Dimension number of Input(X) should be 2.");
+    int ref_level = ctx->Attrs().Get<int>("ref_level");
+
+    if (ctx->IsRuntime()) {
+      framework::Variable* x_var =
+          boost::get<framework::Variable*>(ctx->GetInputVarPtrs("X")[0]);
+      framework::Variable* y_var =
+          boost::get<framework::Variable*>(ctx->GetInputVarPtrs("Y")[0]);
+
+      auto& x_lod = x_var->Get<LoDTensor>().lod();
+      auto& y_lod = y_var->Get<LoDTensor>().lod();
+
+      PADDLE_ENFORCE_LE(x_lod.size(), 1,
+                        "Number of lod levels of Input(X) should not be "
+                        "greater than 1.");
+
+      PADDLE_ENFORCE(x_lod.size() == y_lod.size() || x_lod.size() == 0,
+                     "Number of lod levels of Input(X) should be either 0 "
+                     "or equal to that of Input(Y).");
+
+      int64_t out_first_dim = 0;
+      if (y_lod[ref_level].size() < 1) {
+        out_first_dim = x_dims[0];
+      } else {
+        if (x_lod.size() == 1) {  // X is LoDTensor
+          for (size_t i = 1; i < y_lod[ref_level].size(); ++i) {
+            int x_seq_len = x_lod[0][i] - x_lod[0][i - 1];
+            out_first_dim +=
+                (y_lod[ref_level][i] - y_lod[ref_level][i - 1]) * x_seq_len;
+          }
+        } else {  // X is normal Tensor
+          for (size_t i = 1; i < y_lod[ref_level].size(); ++i) {
+            out_first_dim += y_lod[ref_level][i] - y_lod[ref_level][i - 1];
+          }
+        }
+      }
+      ctx->SetOutputDim("Out", {out_first_dim, x_dims[1]});
+    } else {
+      framework::VarDesc* in_reader =
+          boost::get<framework::VarDesc*>(ctx->GetInputVarPtrs("Y")[0]);
+      int lod_level_num = in_reader->GetLoDLevels().size();
+
+      PADDLE_ENFORCE_GE(ref_level, 0,
+                        "Level of referred lod should be greater than or "
+                        "equal to 0.");
+
+      PADDLE_ENFORCE_LT(ref_level, lod_level_num,
+                        "Level of referred lod should be smaller than the "
+                        "level number of Input(Y).");
+
+      ctx->SetOutputDim("Out", {-1, x_dims[1]});
+    }
   }
 };
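Note on the new runtime branch above: each sequence of X contributes `x_seq_len * repeat_count` rows to Out, where the repeat count is the corresponding sequence length at Y's `ref_level`. A minimal, self-contained sketch of that arithmetic, using hypothetical lod values rather than the Paddle API:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Hypothetical lods: X holds two sequences of 2 rows each (4 rows total);
  // at Y's referred level, each X sequence maps to a repeat count of 2.
  std::vector<size_t> x_lod0 = {0, 2, 4};     // level-0 offsets of X
  std::vector<size_t> y_ref_lod = {0, 2, 4};  // offsets at Y's ref_level

  // Mirrors the loop in the "X is LoDTensor" case of InferShape.
  int64_t out_first_dim = 0;
  for (size_t i = 1; i < y_ref_lod.size(); ++i) {
    int64_t x_seq_len = x_lod0[i] - x_lod0[i - 1];
    out_first_dim +=
        static_cast<int64_t>(y_ref_lod[i] - y_ref_lod[i - 1]) * x_seq_len;
  }
  std::cout << out_first_dim << std::endl;  // prints 8 (= 2*2 + 2*2)
  return 0;
}
```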
@@ -42,17 +94,15 @@ class SequenceExpandOpMaker : public framework::OpProtoAndCheckerMaker {
   SequenceExpandOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X",
-             "(Tensor or LoDTensor) The input(X) of this operator can be a "
-             "LoDTensor or a base Tensor.");
+             "(LoDTensor, default LoDTensor<float>) A 2-D LoDTensor whose lod "
+             "level is at most 1.");
     AddInput("Y",
-             "(LoDTensor)The reference input(Y) of sequence_expand op."
-             "It must be a LoDTensor with k-level(k>0)."
-             "The input(X) will be expanded according to LOD of input(Y)."
-             "The element numbers of last level in input(Y) "
-             "must be equal to dims[0] of input(X).");
+             "(LoDTensor, default LoDTensor<float>) Reference LoDTensor whose "
+             "lod (at the specified level) is used to expand Input(X).");
     AddOutput("Out",
-              "(LodTensor)The output of sequence_expand op."
-              "The lod of output will be as same as input(Y)'s lod.");
+              "(LoDTensor, default LoDTensor<float>) Output LoDTensor "
+              "generated from Input(X) by referring to the lod of Input(Y).");
+    AddAttr<int>("ref_level", "Specify the lod level of Input(Y).");
     AddComment(R"DOC(
 Sequence Expand Operator.
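To make the refreshed docstrings concrete: each sequence of Input(X) is copied as many times as the corresponding sequence length at Input(Y)'s `ref_level`. A hedged sketch of that data movement with made-up values (the actual kernels live elsewhere in the op's header; this only illustrates the semantics):

```cpp
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

int main() {
  // X: two sequences, [a, b] and [c, d].
  std::vector<std::string> x = {"a", "b", "c", "d"};
  std::vector<size_t> x_lod = {0, 2, 4};
  // Y's referred level: sequence 0 repeats once, sequence 1 repeats twice.
  std::vector<size_t> y_ref_lod = {0, 1, 3};

  std::vector<std::string> out;
  for (size_t i = 1; i < y_ref_lod.size(); ++i) {
    size_t repeats = y_ref_lod[i] - y_ref_lod[i - 1];
    for (size_t r = 0; r < repeats; ++r) {
      out.insert(out.end(), x.begin() + x_lod[i - 1], x.begin() + x_lod[i]);
    }
  }
  for (const auto& e : out) std::cout << e << ' ';  // a b c d c d
  std::cout << std::endl;  // Out has 6 rows = 1*2 + 2*2, matching InferShape
  return 0;
}
```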
@@ -129,12 +179,14 @@ class SequenceExpandOpGrad : public framework::OperatorWithKernel {
  protected:
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"));
-    PADDLE_ENFORCE(ctx->HasInput("Out"));
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Out"), "Input(Out) should not be null.");
     PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
-                   "The input(Out@GRAD) should not be null");
+                   "Input(Out@GRAD) should not be null.");
+
     auto x_dims = ctx->GetInputDim("X");
     auto x_grad_name = framework::GradVarName("X");
+
     if (ctx->HasOutput(x_grad_name)) {
       ctx->SetOutputDim(x_grad_name, x_dims);
     }
@@ -149,7 +201,13 @@ REGISTER_OP(sequence_expand, ops::SequenceExpandOp, ops::SequenceExpandOpMaker,
             sequence_expand_grad, ops::SequenceExpandOpGrad);
 REGISTER_OP_CPU_KERNEL(
     sequence_expand,
-    ops::SequenceExpandKernel<paddle::platform::CPUDeviceContext, float>);
+    ops::SequenceExpandKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::SequenceExpandKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::SequenceExpandKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::SequenceExpandKernel<paddle::platform::CPUDeviceContext, int64_t>);
 REGISTER_OP_CPU_KERNEL(
     sequence_expand_grad,
-    ops::SequenceExpandGradKernel<paddle::platform::CPUDeviceContext, float>);
+    ops::SequenceExpandGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::SequenceExpandGradKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::SequenceExpandGradKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::SequenceExpandGradKernel<paddle::platform::CPUDeviceContext, int64_t>);