|
|
@ -127,11 +127,8 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
|
|
|
|
net->ops_[op_offset]->Rename(name, dup_outputs.back());
|
|
|
|
net->ops_[op_offset]->Rename(name, dup_outputs.back());
|
|
|
|
}
|
|
|
|
}
|
|
|
|
insert_position.push_back(
|
|
|
|
insert_position.push_back(
|
|
|
|
{dup_op.back(),
|
|
|
|
{dup_op.back(), OpRegistry::CreateOp("add", {{"X", {dup_outputs}}},
|
|
|
|
OpRegistry::CreateOp(
|
|
|
|
{{"Out", {name}}}, {})});
|
|
|
|
"add", {{"X", {dup_outputs}}}, {{"Out", {name}}},
|
|
|
|
|
|
|
|
{{"input_format",
|
|
|
|
|
|
|
|
std::vector<int>{0, static_cast<int>(dup_outputs.size())}}})});
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
insert_position.sort(
|
|
|
|
insert_position.sort(
|
|
|
@ -140,7 +137,6 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
|
|
|
|
for (auto& pos : insert_position) {
|
|
|
|
for (auto& pos : insert_position) {
|
|
|
|
net->InsertOp(pos.first + 1, pos.second);
|
|
|
|
net->InsertOp(pos.first + 1, pos.second);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
std::shared_ptr<OperatorBase> grad_op = OpRegistry::CreateGradOp(forwardOp);
|
|
|
|
std::shared_ptr<OperatorBase> grad_op = OpRegistry::CreateGradOp(forwardOp);
|
|
|
|
|
|
|
|
|
|
|
@ -176,7 +172,7 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
|
|
|
|
net->type_ = "@GENERATED_BACKWARD@";
|
|
|
|
net->type_ = "@GENERATED_BACKWARD@";
|
|
|
|
net->CompleteAddOp();
|
|
|
|
net->CompleteAddOp();
|
|
|
|
return net;
|
|
|
|
return net;
|
|
|
|
}
|
|
|
|
} // namespace framework
|
|
|
|
|
|
|
|
|
|
|
|
// See header for comments
|
|
|
|
// See header for comments
|
|
|
|
std::shared_ptr<OperatorBase> Backward(
|
|
|
|
std::shared_ptr<OperatorBase> Backward(
|
|
|
|