|
|
@ -144,31 +144,23 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
|
|
|
|
//
|
|
|
|
//
|
|
|
|
// one variable is shared between multiple operators.
|
|
|
|
// one variable is shared between multiple operators.
|
|
|
|
// insert add operator one by one, then add it to output
|
|
|
|
// insert add operator one by one, then add it to output
|
|
|
|
if (dup_outputs.size() == 2) {
|
|
|
|
for (size_t output_idx = 0; output_idx < dup_outputs.size() - 1;
|
|
|
|
|
|
|
|
++output_idx) {
|
|
|
|
|
|
|
|
auto insert_add_x = dup_outputs[output_idx];
|
|
|
|
|
|
|
|
auto insert_add_y = dup_outputs[output_idx];
|
|
|
|
|
|
|
|
auto insert_add_out = name + "@SHARED@" + std::to_string(output_idx);
|
|
|
|
|
|
|
|
// first add op inserted
|
|
|
|
|
|
|
|
if (output_idx == dup_outputs.size() - 2) {
|
|
|
|
|
|
|
|
insert_add_out = name;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if (output_idx != 0) {
|
|
|
|
|
|
|
|
insert_add_y = name + "@SHARED@" + std::to_string(output_idx - 1);
|
|
|
|
|
|
|
|
}
|
|
|
|
insert_position.push_back(
|
|
|
|
insert_position.push_back(
|
|
|
|
{dup_op.back(),
|
|
|
|
{dup_op.back(),
|
|
|
|
OpRegistry::CreateOp(
|
|
|
|
OpRegistry::CreateOp(
|
|
|
|
"add", {{"X", {dup_outputs[0]}}, {"Y", {dup_outputs[1]}}},
|
|
|
|
"add", {{"X", {insert_add_x}}, {"Y", {insert_add_y}}},
|
|
|
|
{{"Out", {name}}}, {})});
|
|
|
|
{{"Out", {insert_add_out}}}, {})});
|
|
|
|
} else {
|
|
|
|
|
|
|
|
for (size_t output_idx = 0; output_idx < dup_outputs.size() - 1;
|
|
|
|
|
|
|
|
++output_idx) {
|
|
|
|
|
|
|
|
auto insert_add_x = dup_outputs[output_idx];
|
|
|
|
|
|
|
|
auto insert_add_y = dup_outputs[output_idx];
|
|
|
|
|
|
|
|
auto insert_add_out = name + "@SHARED@" + std::to_string(output_idx);
|
|
|
|
|
|
|
|
// first add op inserted
|
|
|
|
|
|
|
|
if (output_idx == dup_outputs.size() - 1) {
|
|
|
|
|
|
|
|
insert_add_out = name;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if (output_idx != 0) {
|
|
|
|
|
|
|
|
insert_add_y = name + "@SHARED@" + std::to_string(output_idx-1);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
insert_position.push_back(
|
|
|
|
|
|
|
|
{dup_op.back(),
|
|
|
|
|
|
|
|
OpRegistry::CreateOp(
|
|
|
|
|
|
|
|
"add", {{"X", {insert_add_x}}, {"Y", {insert_add_y}}},
|
|
|
|
|
|
|
|
{{"Out", {insert_add_out}}}, {})});
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|