|
|
|
@ -20,16 +20,14 @@ namespace paddle {
|
|
|
|
|
namespace framework {
|
|
|
|
|
enum class OpArgType { IN, OUT };
|
|
|
|
|
|
|
|
|
|
static void TransOpArg(const OperatorBase* src_op,
|
|
|
|
|
OperatorBase::VarNameMap* vars,
|
|
|
|
|
const OpArgType& src_type, bool is_grad) {
|
|
|
|
|
static void TransOpArg(const OperatorBase* src_op, const OpArgType& src_type,
|
|
|
|
|
bool is_grad, OperatorBase::VarNameMap* vars) {
|
|
|
|
|
const auto& src_inout =
|
|
|
|
|
src_type == OpArgType::IN ? src_op->inputs_ : src_op->outputs_;
|
|
|
|
|
auto& dst_inout = *vars;
|
|
|
|
|
|
|
|
|
|
const OpProto& proto = OpProtos().at(src_op->type_);
|
|
|
|
|
const OpProto* proto = OpRegistry::op_info_map().at(src_op->type_).proto_;
|
|
|
|
|
const auto& src_arg_list =
|
|
|
|
|
src_type == OpArgType::IN ? proto.inputs() : proto.outputs();
|
|
|
|
|
src_type == OpArgType::IN ? proto->inputs() : proto->outputs();
|
|
|
|
|
for (const auto& arg : src_arg_list) {
|
|
|
|
|
if (arg.no_gradient() && !is_grad) continue;
|
|
|
|
|
const std::string src_name = arg.name();
|
|
|
|
@ -43,22 +41,26 @@ static void TransOpArg(const OperatorBase* src_op,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
OperatorBase* BuildGradOp(const OperatorBase* op) {
|
|
|
|
|
auto gop_type_it = OpRegistry::grad_ops().find(op->type_);
|
|
|
|
|
PADDLE_ENFORCE(gop_type_it != OpRegistry::grad_ops().end(),
|
|
|
|
|
"Operator %s do not register gradient type", op->type_);
|
|
|
|
|
auto& grad_op_type = gop_type_it->second;
|
|
|
|
|
auto it = OpRegistry::op_info_map().find(op->type_);
|
|
|
|
|
PADDLE_ENFORCE(it != OpRegistry::op_info_map().end(),
|
|
|
|
|
"'%s' has not been registered.", op->type_);
|
|
|
|
|
PADDLE_ENFORCE(it->second.proto_ != nullptr, "'%s' has no OpProto.",
|
|
|
|
|
op->type_);
|
|
|
|
|
std::string grad_op_type = it->second.grad_op_type_;
|
|
|
|
|
PADDLE_ENFORCE(!grad_op_type.empty(), "'%s' has no gradient operator.",
|
|
|
|
|
op->type_);
|
|
|
|
|
|
|
|
|
|
OperatorBase::VarNameMap inputs;
|
|
|
|
|
OperatorBase::VarNameMap outputs;
|
|
|
|
|
TransOpArg(op, &inputs, OpArgType::IN, false); // I
|
|
|
|
|
TransOpArg(op, &inputs, OpArgType::OUT, false); // O
|
|
|
|
|
TransOpArg(op, &inputs, OpArgType::OUT, true); // OG
|
|
|
|
|
TransOpArg(op, &outputs, OpArgType::IN, true); // IG
|
|
|
|
|
auto gop_it = OpRegistry::op_creators().find(grad_op_type);
|
|
|
|
|
PADDLE_ENFORCE(gop_it != OpRegistry::op_creators().end(),
|
|
|
|
|
"Operator %s 's Gradient %s's creator cannot be found",
|
|
|
|
|
op->type_, grad_op_type);
|
|
|
|
|
TransOpArg(op, OpArgType::IN, false, &inputs); // I
|
|
|
|
|
TransOpArg(op, OpArgType::OUT, false, &inputs); // O
|
|
|
|
|
TransOpArg(op, OpArgType::OUT, true, &inputs); // OG
|
|
|
|
|
TransOpArg(op, OpArgType::IN, true, &outputs); // IG
|
|
|
|
|
|
|
|
|
|
return gop_it->second(grad_op_type, inputs, outputs, op->attrs_);
|
|
|
|
|
it = OpRegistry::op_info_map().find(grad_op_type);
|
|
|
|
|
PADDLE_ENFORCE(it != OpRegistry::op_info_map().end(),
|
|
|
|
|
"'%s' has not been registered.", grad_op_type);
|
|
|
|
|
return it->second.creator_(grad_op_type, inputs, outputs, op->attrs_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace framework
|
|
|
|
|