|
|
|
@ -115,12 +115,32 @@ class SequencePoolGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class SequencePoolGradOpMaker : public framework::SingleGradOpDescMaker {
|
|
|
|
|
public:
|
|
|
|
|
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
std::unique_ptr<framework::OpDesc> Apply() const override {
|
|
|
|
|
auto* op_desc_ptr = new framework::OpDesc();
|
|
|
|
|
op_desc_ptr->SetType("sequence_pool_grad");
|
|
|
|
|
op_desc_ptr->SetInput("X", Input("X"));
|
|
|
|
|
if (boost::get<std::string>(GetAttr("pooltype")) == "MAX") {
|
|
|
|
|
op_desc_ptr->SetInput("MaxIndex", Output("MaxIndex"));
|
|
|
|
|
}
|
|
|
|
|
op_desc_ptr->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
|
|
|
|
|
op_desc_ptr->SetOutput(framework::GradVarName("X"), InputGrad("X"));
|
|
|
|
|
op_desc_ptr->SetAttrMap(Attrs());
|
|
|
|
|
return std::unique_ptr<framework::OpDesc>(op_desc_ptr);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
|
REGISTER_OP(sequence_pool, ops::SequencePoolOp, ops::SequencePoolOpMaker,
|
|
|
|
|
sequence_pool_grad, ops::SequencePoolGradOp);
|
|
|
|
|
REGISTER_OPERATOR(sequence_pool, ops::SequencePoolOp, ops::SequencePoolOpMaker,
|
|
|
|
|
ops::SequencePoolGradOpMaker);
|
|
|
|
|
REGISTER_OPERATOR(sequence_pool_grad, ops::SequencePoolGradOp);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
sequence_pool,
|
|
|
|
|
ops::SequencePoolKernel<paddle::platform::CPUDeviceContext, float>);
|
|
|
|
|