Remove intermediate output's gradient from inputs of grad_op.

wangkuiyi-patch-2
wanghaoshuang 7 years ago
parent 387e10c6cd
commit 00548a1601

@ -194,13 +194,45 @@ class GRUUnitGradOp : public framework::OperatorWithKernel {
}
};
class GRUUnitGradOpMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
auto* op = new framework::OpDesc();
op->SetType("gru_unit_grad");
op->SetInput("Input", Input("Input"));
op->SetInput("HiddenPrev", Input("HiddenPrev"));
op->SetInput("Weight", Input("Weight"));
op->SetInput("Bias", Input("Bias"));
op->SetInput("Hidden", Output("Hidden"));
op->SetInput("Gate", Output("Gate"));
op->SetInput("ResetHiddenPrev", Output("ResetHiddenPrev"));
op->SetInput(framework::GradVarName("Hidden"), OutputGrad("Hidden"));
op->SetAttrMap(Attrs());
op->SetOutput(framework::GradVarName("Input"), InputGrad("Input"));
op->SetOutput(framework::GradVarName("HiddenPrev"),
InputGrad("HiddenPrev"));
op->SetOutput(framework::GradVarName("Weight"), InputGrad("Weight"));
op->SetOutput(framework::GradVarName("Bias"), InputGrad("Bias"));
return std::unique_ptr<framework::OpDesc>(op);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(gru_unit, ops::GRUUnitOp, ops::GRUUnitOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>)
REGISTER_OPERATOR(gru_unit_grad, ops::GRUUnitGradOp)
ops::GRUUnitGradOpMaker);
REGISTER_OPERATOR(gru_unit_grad, ops::GRUUnitGradOp);
REGISTER_OP_CPU_KERNEL(
gru_unit, ops::GRUUnitKernel<paddle::platform::CPUDeviceContext, float>,
ops::GRUUnitKernel<paddle::platform::CPUDeviceContext, double>);

Loading…
Cancel
Save