|
|
|
@ -8,6 +8,8 @@ namespace framework {
|
|
|
|
|
class OpRegistry;
|
|
|
|
|
|
|
|
|
|
class GradOpCreator {
|
|
|
|
|
using VarIndexMap = std::unordered_map<std::string, int>;
|
|
|
|
|
|
|
|
|
|
public:
|
|
|
|
|
GradOpCreator(const OperatorBase* op) : op_(op) {}
|
|
|
|
|
OperatorBase* Create();
|
|
|
|
@ -32,15 +34,15 @@ class GradOpCreator {
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
OpInOutArg* BuildArg(const VarProto& var, const VarIndexMap& var_map,
|
|
|
|
|
const vector<int>& format, InOutType type);
|
|
|
|
|
const std::vector<int>& format, InOutType type);
|
|
|
|
|
void BuildOpInOutArgList();
|
|
|
|
|
void PushArgIntoGradOp(const OpInOutArg* arg, vector<std::string>& in_out,
|
|
|
|
|
vector<int>& format, VarIndexMap* varmap, int& idx,
|
|
|
|
|
bool is_grad);
|
|
|
|
|
void AddArgIntoGradOp(const OpInOutArg* arg, std::vector<std::string>& in_out,
|
|
|
|
|
std::vector<int>& format, VarIndexMap* varmap, int& idx,
|
|
|
|
|
bool is_grad);
|
|
|
|
|
void CompleteGradOp(OperatorBase* grad_op) const;
|
|
|
|
|
const OperatorBase* op_;
|
|
|
|
|
std::vector<std::shared_ptr<OpInOutArg>> arg_list_;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace framework
|
|
|
|
|
} // namespace paddle
|
|
|
|
|