|
|
|
@ -60,9 +60,14 @@ std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDescBind& op_desc) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<std::unique_ptr<OpDescBind>> OpRegistry::CreateGradOpDescs(
|
|
|
|
|
const OpDescBind& op_desc) {
|
|
|
|
|
auto& info = OpInfoMap::Instance().Get(op_desc.Type());
|
|
|
|
|
return info.grad_op_maker_(op_desc);
|
|
|
|
|
OpDescBind* op_desc) {
|
|
|
|
|
auto& info = OpInfoMap::Instance().Get(op_desc->Type());
|
|
|
|
|
|
|
|
|
|
if (info.Checker() != nullptr) {
|
|
|
|
|
info.Checker()->Check(*op_desc->MutableAttrMap());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return info.grad_op_maker_(*op_desc);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace framework
|
|
|
|
|