|
|
|
@ -158,9 +158,8 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
|
|
|
|
|
}
|
|
|
|
|
insert_position.push_back(
|
|
|
|
|
{dup_op.back(),
|
|
|
|
|
OpRegistry::CreateOp(
|
|
|
|
|
"sum", {{"X", {insert_add_x}}, {"X", {insert_add_y}}},
|
|
|
|
|
{{"Out", {insert_add_out}}}, {})});
|
|
|
|
|
OpRegistry::CreateOp("sum", {{"X", {insert_add_x, insert_add_y}}},
|
|
|
|
|
{{"Out", {insert_add_out}}}, {})});
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -200,7 +199,8 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
|
|
|
|
|
|
|
|
|
|
// process recurrent gradient op as a special operator.
|
|
|
|
|
if (forwardOp.Type() == "recurrent") {
|
|
|
|
|
// NOTE clean up cycle call somewhere (RNN's stepnet constains itself), or
|
|
|
|
|
// NOTE clean up cycle call somewhere (RNN's stepnet constains itself),
|
|
|
|
|
// or
|
|
|
|
|
// this will result in infinite loop.
|
|
|
|
|
const auto& rnnop =
|
|
|
|
|
*static_cast<const operators::RecurrentOp*>(&forwardOp);
|
|
|
|
|