|
|
|
@ -131,11 +131,9 @@ class CVMGradOpDescMaker : public framework::SingleGradOpDescMaker {
|
|
|
|
|
std::unique_ptr<framework::OpDesc> Apply() const override {
|
|
|
|
|
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
|
|
|
|
|
op->SetType("cvm_grad");
|
|
|
|
|
op->SetInput("X", Input("X"));
|
|
|
|
|
op->SetInput("CVM", Input("CVM"));
|
|
|
|
|
op->SetInput(framework::GradVarName("Y"), OutputGrad("Y"));
|
|
|
|
|
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
|
|
|
|
|
op->SetOutput(framework::GradVarName("CVM"), InputGrad("CVM"));
|
|
|
|
|
op->SetAttrMap(Attrs());
|
|
|
|
|
return op;
|
|
|
|
|
}
|
|
|
|
|