|
|
|
@ -72,19 +72,26 @@ class MinusGradMaker : public framework::GradOpDescMakerBase {
|
|
|
|
|
std::vector<std::unique_ptr<framework::OpDescBind>> operator()()
|
|
|
|
|
const override {
|
|
|
|
|
std::vector<std::unique_ptr<framework::OpDescBind>> ops;
|
|
|
|
|
ops.resize(2);
|
|
|
|
|
|
|
|
|
|
ops[0].reset(new framework::OpDescBind());
|
|
|
|
|
ops[0]->SetType("scale");
|
|
|
|
|
ops[0]->SetInput("X", OutputGrad("Out"));
|
|
|
|
|
ops[0]->SetOutput("Out", InputGrad("X"));
|
|
|
|
|
ops[0]->SetAttr("scale", 1.0f);
|
|
|
|
|
|
|
|
|
|
ops[1].reset(new framework::OpDescBind());
|
|
|
|
|
ops[1]->SetType("scale");
|
|
|
|
|
ops[1]->SetInput("X", OutputGrad("Out"));
|
|
|
|
|
ops[1]->SetOutput("Out", InputGrad("Y"));
|
|
|
|
|
ops[1]->SetAttr("scale", -1.0f);
|
|
|
|
|
auto x_g = InputGrad("X");
|
|
|
|
|
if (!x_g.empty()) {
|
|
|
|
|
auto *x_g_op = new framework::OpDescBind();
|
|
|
|
|
x_g_op->SetType("scale");
|
|
|
|
|
x_g_op->SetInput("X", OutputGrad("Out"));
|
|
|
|
|
x_g_op->SetOutput("Out", x_g);
|
|
|
|
|
x_g_op->SetAttr("scale", 1.0f);
|
|
|
|
|
ops.emplace_back(x_g_op);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto y_g = InputGrad("Y");
|
|
|
|
|
if (!y_g.empty()) {
|
|
|
|
|
auto *y_g_op = new framework::OpDescBind();
|
|
|
|
|
y_g_op->SetType("scale");
|
|
|
|
|
y_g_op->SetInput("X", OutputGrad("Out"));
|
|
|
|
|
y_g_op->SetOutput("Out", y_g);
|
|
|
|
|
y_g_op->SetAttr("scale", -1.0f);
|
|
|
|
|
ops.emplace_back(y_g_op);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return ops;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|