|
|
|
@ -20,12 +20,6 @@
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace imperative {
|
|
|
|
|
|
|
|
|
|
struct OpBaseCmp {
|
|
|
|
|
bool operator()(OpBase* first, OpBase* second) {
|
|
|
|
|
return first->id() > second->id();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
static std::vector<std::unique_ptr<OpBase>> CreateGradOpBases(
|
|
|
|
|
const OpBase* fw_op_base, const NameVarBaseMap& in,
|
|
|
|
|
const NameVarBaseMap& out) {
|
|
|
|
@ -130,7 +124,7 @@ void Tracer::TraceBackward(const std::shared_ptr<OpBase>& fwd_op,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::set<OpBase*, OpBaseCmp> visited_preceding_ops;
|
|
|
|
|
std::set<OpBase*> visited_preceding_ops;
|
|
|
|
|
for (auto& grad_out_it : grad_out) {
|
|
|
|
|
bool flag_clear_list = false;
|
|
|
|
|
for (auto& var_base_it : grad_out_it.second) {
|
|
|
|
|