|
|
|
@ -133,8 +133,9 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
|
|
|
|
|
std::shared_ptr<OperatorBase> grad_op = OpRegistry::CreateGradOp(forwardOp);
|
|
|
|
|
for (std::string& grad_input : grad_op->inputs_) {
|
|
|
|
|
if (no_grad_names.count(grad_input)) {
|
|
|
|
|
// +1 for \0
|
|
|
|
|
std::string prefix = grad_input.substr(
|
|
|
|
|
0, grad_input.size() - sizeof(kGradVarSuffix) / sizeof(char));
|
|
|
|
|
0, grad_input.size() - sizeof(kGradVarSuffix) / sizeof(char) + 1);
|
|
|
|
|
grad_input = prefix + kZeroVarSuffix;
|
|
|
|
|
|
|
|
|
|
// If part of input gradient of that operator is not calculated, fill
|
|
|
|
|