|
|
|
@ -60,7 +60,8 @@ As an example:
|
|
|
|
|
|
|
|
|
|
Given:
|
|
|
|
|
|
|
|
|
|
X = [1, 2 , 3]
|
|
|
|
|
X.data = [1, 2 , 3, 4]
|
|
|
|
|
X.lod = [[0, 3, 4], [0, 1, 3, 4]]
|
|
|
|
|
|
|
|
|
|
and
|
|
|
|
|
|
|
|
|
@ -69,8 +70,8 @@ repeat = 2
|
|
|
|
|
|
|
|
|
|
then we get
|
|
|
|
|
|
|
|
|
|
Out.data = [1, 1, 2, 2, 3, 3]
|
|
|
|
|
Out.lod = [[0, 2, 4, 6]]
|
|
|
|
|
Out.data = [1, 2, 3, 1, 2, 3, 4, 4]
|
|
|
|
|
Out.lod = [[0, 6, 8], [0, 3, 6, 7, 8], [0, 1, 3, 4, 6, 7, 8]]
|
|
|
|
|
|
|
|
|
|
)DOC");
|
|
|
|
|
}
|
|
|
|
@ -83,6 +84,7 @@ class SeqExpandOpGrad : public framework::OperatorWithKernel {
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Out"), "Input(Out) should not be null");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
|
|
|
|
|
"Input(Out@GRAD) should not be null");
|
|
|
|
|
auto x_dims = ctx->GetInputDim("X");
|
|
|
|
@ -93,30 +95,12 @@ class SeqExpandOpGrad : public framework::OperatorWithKernel {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class SeqExpandOpGradMaker : public framework::SingleGradOpDescMaker {
|
|
|
|
|
public:
|
|
|
|
|
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
std::unique_ptr<framework::OpDescBind> Apply() const override {
|
|
|
|
|
auto* bind = new framework::OpDescBind();
|
|
|
|
|
bind->SetInput("X", Input("X"));
|
|
|
|
|
bind->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
|
|
|
|
|
bind->SetOutput(framework::GradVarName("X"), InputGrad("X"));
|
|
|
|
|
bind->SetAttrMap(Attrs());
|
|
|
|
|
bind->SetType("seq_expand_grad");
|
|
|
|
|
return std::unique_ptr<framework::OpDescBind>(bind);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
|
|
|
|
|
|
REGISTER_OPERATOR(seq_expand, ops::SeqExpandOp, ops::SeqExpandOpMaker,
|
|
|
|
|
ops::SeqExpandOpGradMaker);
|
|
|
|
|
REGISTER_OPERATOR(seq_expand_grad, ops::SeqExpandOpGrad);
|
|
|
|
|
REGISTER_OP(seq_expand, ops::SeqExpandOp, ops::SeqExpandOpMaker,
|
|
|
|
|
seq_expand_grad, ops::SeqExpandOpGrad);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(seq_expand,
|
|
|
|
|
ops::SeqExpandKernel<paddle::platform::CPUPlace, float>);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
|