|
|
|
@ -282,15 +282,22 @@ All parameter, weight, gradient are variables in Paddle.
|
|
|
|
|
}
|
|
|
|
|
return ret_values;
|
|
|
|
|
});
|
|
|
|
|
m.def("get_grad_op_desc",
|
|
|
|
|
m.def("get_grad_op_descs",
|
|
|
|
|
[](const OpDescBind &op_desc,
|
|
|
|
|
const std::unordered_set<std::string> &no_grad_set,
|
|
|
|
|
std::unordered_map<std::string, std::string> &grad_to_var,
|
|
|
|
|
const std::vector<BlockDescBind *> &grad_sub_block) {
|
|
|
|
|
return framework::OpInfoMap::Instance()
|
|
|
|
|
.Get(op_desc.Type())
|
|
|
|
|
.GradOpMaker()(op_desc, no_grad_set, &grad_to_var,
|
|
|
|
|
grad_sub_block);
|
|
|
|
|
std::vector<std::unique_ptr<OpDescBind>> grad_op_descs =
|
|
|
|
|
framework::OpInfoMap::Instance()
|
|
|
|
|
.Get(op_desc.Type())
|
|
|
|
|
.GradOpMaker()(op_desc, no_grad_set, &grad_to_var,
|
|
|
|
|
grad_sub_block);
|
|
|
|
|
std::vector<OpDescBind *> grad_op_desc_ptrs(grad_op_descs.size());
|
|
|
|
|
std::transform(
|
|
|
|
|
grad_op_descs.begin(), grad_op_descs.end(),
|
|
|
|
|
grad_op_desc_ptrs.begin(),
|
|
|
|
|
[](std::unique_ptr<OpDescBind> &p) { return p.release(); });
|
|
|
|
|
return grad_op_desc_ptrs;
|
|
|
|
|
});
|
|
|
|
|
m.def("prune", [](const ProgramDescBind &origin,
|
|
|
|
|
const std::vector<std::array<size_t, 2>> &targets) {
|
|
|
|
|