|
|
|
@ -194,13 +194,45 @@ class GRUUnitGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class GRUUnitGradOpMaker : public framework::SingleGradOpDescMaker {
|
|
|
|
|
public:
|
|
|
|
|
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
std::unique_ptr<framework::OpDesc> Apply() const override {
|
|
|
|
|
auto* op = new framework::OpDesc();
|
|
|
|
|
op->SetType("gru_unit_grad");
|
|
|
|
|
|
|
|
|
|
op->SetInput("Input", Input("Input"));
|
|
|
|
|
op->SetInput("HiddenPrev", Input("HiddenPrev"));
|
|
|
|
|
op->SetInput("Weight", Input("Weight"));
|
|
|
|
|
op->SetInput("Bias", Input("Bias"));
|
|
|
|
|
|
|
|
|
|
op->SetInput("Hidden", Output("Hidden"));
|
|
|
|
|
op->SetInput("Gate", Output("Gate"));
|
|
|
|
|
op->SetInput("ResetHiddenPrev", Output("ResetHiddenPrev"));
|
|
|
|
|
op->SetInput(framework::GradVarName("Hidden"), OutputGrad("Hidden"));
|
|
|
|
|
|
|
|
|
|
op->SetAttrMap(Attrs());
|
|
|
|
|
|
|
|
|
|
op->SetOutput(framework::GradVarName("Input"), InputGrad("Input"));
|
|
|
|
|
op->SetOutput(framework::GradVarName("HiddenPrev"),
|
|
|
|
|
InputGrad("HiddenPrev"));
|
|
|
|
|
op->SetOutput(framework::GradVarName("Weight"), InputGrad("Weight"));
|
|
|
|
|
op->SetOutput(framework::GradVarName("Bias"), InputGrad("Bias"));
|
|
|
|
|
return std::unique_ptr<framework::OpDesc>(op);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
|
|
|
|
|
|
REGISTER_OPERATOR(gru_unit, ops::GRUUnitOp, ops::GRUUnitOpMaker,
|
|
|
|
|
paddle::framework::DefaultGradOpDescMaker<true>)
|
|
|
|
|
REGISTER_OPERATOR(gru_unit_grad, ops::GRUUnitGradOp)
|
|
|
|
|
ops::GRUUnitGradOpMaker);
|
|
|
|
|
REGISTER_OPERATOR(gru_unit_grad, ops::GRUUnitGradOp);
|
|
|
|
|
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
gru_unit, ops::GRUUnitKernel<paddle::platform::CPUDeviceContext, float>,
|
|
|
|
|
ops::GRUUnitKernel<paddle::platform::CPUDeviceContext, double>);
|
|
|
|
|