fix backward

revert-3824-remove_grad_op_type
qiaolongfei 8 years ago
parent 5b7633a55f
commit a240bce152

@ -79,9 +79,9 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
// All output gradients of forwarding operator do not need to calculate.
// Then all input gradients cannot be computed at all, and we put them into
// `no_grad_names` set. Return an NOP.
if (AllInSet(forwardOp.Output() /*names*/, kGradVarSuffix /*suffix*/,
if (AllInSet(forwardOp.Outputs() /*names*/, kGradVarSuffix /*suffix*/,
no_grad_names /*set*/)) {
ForEachVarName(forwardOp.inputs_,
ForEachVarName(forwardOp.Inputs(),
[&no_grad_names](const std::string& name) -> bool {
no_grad_names.insert(GradVarName(name));
return false;

Loading…
Cancel
Save