|
|
|
@ -124,6 +124,9 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
|
|
|
|
|
std::list<Pos> insert_position;
|
|
|
|
|
for (auto& dup_output_op : dup_output_ops) {
|
|
|
|
|
const std::string& name = dup_output_op.first;
|
|
|
|
|
// duplicate @Empty@ don't need to be added
|
|
|
|
|
if (name == kEmptyVarName) continue;
|
|
|
|
|
|
|
|
|
|
auto& dup_op = dup_output_op.second;
|
|
|
|
|
// no duplicate output
|
|
|
|
|
if (dup_op.size() == 1) continue;
|
|
|
|
@ -209,7 +212,7 @@ std::unique_ptr<OperatorBase> Backward(
|
|
|
|
|
const OperatorBase& forwardOp,
|
|
|
|
|
const std::unordered_set<std::string>& no_grad_vars) {
|
|
|
|
|
std::unordered_set<std::string> no_grad_names;
|
|
|
|
|
no_grad_names.reserve(no_grad_vars.size());
|
|
|
|
|
no_grad_names.reserve(no_grad_vars.size() + 1);
|
|
|
|
|
|
|
|
|
|
no_grad_names.insert(std::string(kEmptyVarName) + kGradVarSuffix);
|
|
|
|
|
|
|
|
|
|