|
|
@ -60,19 +60,20 @@ class SumGradMaker : public framework::GradOpDescMakerBase {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
using framework::GradOpDescMakerBase::GradOpDescMakerBase;
|
|
|
|
using framework::GradOpDescMakerBase::GradOpDescMakerBase;
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<framework::OpDescBind> operator()() const override {
|
|
|
|
std::vector<std::unique_ptr<framework::OpDescBind>> operator()()
|
|
|
|
|
|
|
|
const override {
|
|
|
|
auto x_grads = InputGrad("X");
|
|
|
|
auto x_grads = InputGrad("X");
|
|
|
|
std::vector<framework::OpDescBind> grad_ops;
|
|
|
|
std::vector<std::unique_ptr<framework::OpDescBind>> grad_ops;
|
|
|
|
grad_ops.reserve(x_grads.size());
|
|
|
|
grad_ops.reserve(x_grads.size());
|
|
|
|
auto og = OutputGrad("Out");
|
|
|
|
auto og = OutputGrad("Out");
|
|
|
|
std::transform(x_grads.begin(), x_grads.end(), std::back_inserter(grad_ops),
|
|
|
|
std::transform(x_grads.begin(), x_grads.end(), std::back_inserter(grad_ops),
|
|
|
|
[&og](const std::string& x_grad) {
|
|
|
|
[&og](const std::string& x_grad) {
|
|
|
|
framework::OpDescBind grad_op;
|
|
|
|
auto* grad_op = new framework::OpDescBind();
|
|
|
|
grad_op.SetType("scale");
|
|
|
|
grad_op->SetType("scale");
|
|
|
|
grad_op.SetInput("X", og);
|
|
|
|
grad_op->SetInput("X", og);
|
|
|
|
grad_op.SetOutput("Out", {x_grad});
|
|
|
|
grad_op->SetOutput("Out", {x_grad});
|
|
|
|
grad_op.SetAttr("scale", 1.0f);
|
|
|
|
grad_op->SetAttr("scale", 1.0f);
|
|
|
|
return grad_op;
|
|
|
|
return std::unique_ptr<framework::OpDescBind>(grad_op);
|
|
|
|
});
|
|
|
|
});
|
|
|
|
return grad_ops;
|
|
|
|
return grad_ops;
|
|
|
|
}
|
|
|
|
}
|
|
|
|