|
|
|
@ -79,9 +79,9 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
|
|
|
|
|
// All output gradients of forwarding operator do not need to calculate.
|
|
|
|
|
// Then all input gradients cannot be computed at all, and we put them into
|
|
|
|
|
// `no_grad_names` set. Return an NOP.
|
|
|
|
|
if (AllInSet(forwardOp.Output() /*names*/, kGradVarSuffix /*suffix*/,
|
|
|
|
|
if (AllInSet(forwardOp.Outputs() /*names*/, kGradVarSuffix /*suffix*/,
|
|
|
|
|
no_grad_names /*set*/)) {
|
|
|
|
|
ForEachVarName(forwardOp.inputs_,
|
|
|
|
|
ForEachVarName(forwardOp.Inputs(),
|
|
|
|
|
[&no_grad_names](const std::string& name) -> bool {
|
|
|
|
|
no_grad_names.insert(GradVarName(name));
|
|
|
|
|
return false;
|
|
|
|
|