|
|
|
@ -13,6 +13,7 @@
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/framework/backward.h"
|
|
|
|
|
#include <list>
|
|
|
|
|
#include "paddle/framework/net.h"
|
|
|
|
|
#include "paddle/framework/op_registry.h"
|
|
|
|
|
|
|
|
|
@ -71,7 +72,7 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
|
|
|
|
|
return EmptyOp();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// auto* net = new NetOp();
|
|
|
|
|
auto* net = new NetOp();
|
|
|
|
|
|
|
|
|
|
if (forwardOp.IsNetOp()) {
|
|
|
|
|
//! TODO(dzh)
|
|
|
|
@ -93,29 +94,32 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
|
|
|
|
|
}
|
|
|
|
|
// unique the duplicate name
|
|
|
|
|
auto uid = uniq_id++;
|
|
|
|
|
std::unordered_map<size_t, OperatorBase> insert_postion;
|
|
|
|
|
// TODO(dzh): more comment
|
|
|
|
|
typedef std::pair<size_t, std::shared_ptr<OperatorBase>> Pos;
|
|
|
|
|
std::list<Pos> insert_postion;
|
|
|
|
|
for (auto& dup_output_op : dup_output_ops) {
|
|
|
|
|
std::string& name = dup_output_op.first;
|
|
|
|
|
const std::string& name = dup_output_op.first;
|
|
|
|
|
auto& dup_op = dup_output_op.second;
|
|
|
|
|
if (dup_op.size() == 1) continue;
|
|
|
|
|
std::vector<std::string> dup_outputs;
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < dup_op.size(); ++i) {
|
|
|
|
|
auto op_offset = dup_op[i];
|
|
|
|
|
net->ops_[op_offset].Rename(
|
|
|
|
|
name,
|
|
|
|
|
name + "@RENAME@" + std::to_string(uid) + "@" + std::to_string(i));
|
|
|
|
|
dup_outputs.push_back(name + "@RENAME@" + std::to_string(uid) + "@" +
|
|
|
|
|
std::to_string(i));
|
|
|
|
|
net->ops_[op_offset]->Rename(name, dup_outputs.back());
|
|
|
|
|
}
|
|
|
|
|
insert_postion[op_offset] =
|
|
|
|
|
OpRegistry::CreateOp("Add", {}, {dup_op->inputs_}, {});
|
|
|
|
|
net->AddOp("Add");
|
|
|
|
|
net->AddOp();
|
|
|
|
|
// process shared variable
|
|
|
|
|
// while(dup_op.size()) {
|
|
|
|
|
//
|
|
|
|
|
// AddOp(OpRegistry::CreateOp("generic_add", {dup_outputs},
|
|
|
|
|
// {dup_op->inputs_}, {}));
|
|
|
|
|
//}
|
|
|
|
|
insert_postion.push_back(
|
|
|
|
|
{dup_op.back(),
|
|
|
|
|
OpRegistry::CreateOp(
|
|
|
|
|
"Add", {dup_outputs}, {name},
|
|
|
|
|
{{"input_format",
|
|
|
|
|
std::vector<int>{0, (int)dup_outputs.size()}}})});
|
|
|
|
|
}
|
|
|
|
|
insert_postion.sort(
|
|
|
|
|
[](const Pos& l, const Pos& r) { return l.first > r.first; });
|
|
|
|
|
for (auto& pos : insert_postion) {
|
|
|
|
|
net->InsertOp(pos.first, pos.second);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|