|
|
|
@ -52,6 +52,11 @@ static std::shared_ptr<OperatorBase> EmptyOp() {
|
|
|
|
|
static std::shared_ptr<OperatorBase> BackwardImpl(
|
|
|
|
|
const OperatorBase& forwardOp,
|
|
|
|
|
std::unordered_set<std::string>& no_grad_names, size_t& uniq_id) {
|
|
|
|
|
// struct OpIdentity {
|
|
|
|
|
// size_t local_op_id;
|
|
|
|
|
// size_t op_output_offset;
|
|
|
|
|
// };
|
|
|
|
|
|
|
|
|
|
if (AllInSet(forwardOp.inputs_, OperatorBase::GRAD_VAR_SUFFIX(),
|
|
|
|
|
no_grad_names)) {
|
|
|
|
|
return EmptyOp();
|
|
|
|
@ -66,44 +71,51 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
|
|
|
|
|
return EmptyOp();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto* net = new NetOp();
|
|
|
|
|
// auto* net = new NetOp();
|
|
|
|
|
|
|
|
|
|
if (forwardOp.IsNetOp()) {
|
|
|
|
|
//! TODO(dzh)
|
|
|
|
|
std::unordered_map<std::string, int> dup_output;
|
|
|
|
|
std::unordered_map<std::string, std::vector<int>> dup_output_ops;
|
|
|
|
|
// const unsigned uniq_id_local = uniq_id;
|
|
|
|
|
int op_id_offset = 0;
|
|
|
|
|
std::unordered_map<std::string /*var name*/,
|
|
|
|
|
std::vector<size_t> /*op offs et*/>
|
|
|
|
|
dup_output_ops;
|
|
|
|
|
size_t local_op_id = 0;
|
|
|
|
|
// Because it is a net op, it can static_cast.
|
|
|
|
|
auto& forwardNet = static_cast<const NetOp&>(forwardOp);
|
|
|
|
|
|
|
|
|
|
// traverse each sub-net/op of the forward net
|
|
|
|
|
for (auto& fwd : forwardNet.ops_) {
|
|
|
|
|
auto bwd = Backward(*fwd, no_grad_names);
|
|
|
|
|
net->AddOp(bwd);
|
|
|
|
|
for (size_t i = 0; i < bwd->outputs_.size(); ++i) {
|
|
|
|
|
bwd->outputs_[i] += OperatorBase::EMPTY_VAR_NAME();
|
|
|
|
|
if (dup_output.find(bwd->inputs_[i]) == dup_output.end()) {
|
|
|
|
|
dup_output[bwd->inputs_[i]] = 1;
|
|
|
|
|
dup_output_ops[bwd->inputs_[i]] = std::vector<int>{op_id_offset++};
|
|
|
|
|
} else {
|
|
|
|
|
dup_output[bwd->inputs_[i]]++;
|
|
|
|
|
dup_output_ops[bwd->inputs_[i]].emplace_back(op_id_offset++);
|
|
|
|
|
}
|
|
|
|
|
dup_output_ops[bwd->outputs_[i]].emplace_back(local_op_id);
|
|
|
|
|
}
|
|
|
|
|
local_op_id++;
|
|
|
|
|
}
|
|
|
|
|
for (auto dup : dup_output) {
|
|
|
|
|
if (dup.second == 1) continue;
|
|
|
|
|
auto op_ids = dup_output_ops.at(dup.first);
|
|
|
|
|
for (auto& op_id : op_ids) {
|
|
|
|
|
auto& op_ptr = net->ops_[op_id];
|
|
|
|
|
for (size_t i = 0; i < op_ptr->inputs_.size(); ++i) {
|
|
|
|
|
if (op_ptr->inputs_[i] == dup.first) {
|
|
|
|
|
// unique the duplicate name
|
|
|
|
|
op_ptr->inputs_[i] += std::to_string(uniq_id++);
|
|
|
|
|
// TODO(dzh): need a generic add op here
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// unique the duplicate name
|
|
|
|
|
auto uid = uniq_id++;
|
|
|
|
|
std::unordered_map<size_t, OperatorBase> insert_postion;
|
|
|
|
|
for (auto& dup_output_op : dup_output_ops) {
|
|
|
|
|
std::string& name = dup_output_op.first;
|
|
|
|
|
auto& dup_op = dup_output_op.second;
|
|
|
|
|
if (dup_op.size() == 1) continue;
|
|
|
|
|
std::vector<std::string> dup_outputs;
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < dup_op.size(); ++i) {
|
|
|
|
|
auto op_offset = dup_op[i];
|
|
|
|
|
net->ops_[op_offset].Rename(
|
|
|
|
|
name,
|
|
|
|
|
name + "@RENAME@" + std::to_string(uid) + "@" + std::to_string(i));
|
|
|
|
|
}
|
|
|
|
|
insert_postion[op_offset] =
|
|
|
|
|
OpRegistry::CreateOp("Add", {}, {dup_op->inputs_}, {});
|
|
|
|
|
net->AddOp("Add");
|
|
|
|
|
net->AddOp();
|
|
|
|
|
// process shared variable
|
|
|
|
|
// while(dup_op.size()) {
|
|
|
|
|
//
|
|
|
|
|
// AddOp(OpRegistry::CreateOp("generic_add", {dup_outputs},
|
|
|
|
|
// {dup_op->inputs_}, {}));
|
|
|
|
|
//}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|