|
|
|
@ -49,9 +49,11 @@ static std::shared_ptr<OperatorBase> EmptyOp() {
|
|
|
|
|
return net_op;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static std::shared_ptr<OperatorBase> BackwardImpl(
|
|
|
|
|
const OperatorBase& forwardOp,
|
|
|
|
|
std::unordered_set<std::string>& no_grad_names, int& uniq_id) {
|
|
|
|
|
static void DeDuplicate(NetOp* net, std::unordered_se)
|
|
|
|
|
|
|
|
|
|
static std::shared_ptr<OperatorBase> BackwardImpl(
|
|
|
|
|
const OperatorBase& forwardOp,
|
|
|
|
|
std::unordered_set<std::string>& no_grad_names, unsigned& uniq_id) {
|
|
|
|
|
if (AllInSet(forwardOp.inputs_, OperatorBase::GRAD_VAR_SUFFIX(),
|
|
|
|
|
no_grad_names)) {
|
|
|
|
|
return EmptyOp();
|
|
|
|
@ -70,6 +72,39 @@ static std::shared_ptr<OperatorBase> BackwardImpl(
|
|
|
|
|
|
|
|
|
|
if (forwardOp.IsNetOp()) {
|
|
|
|
|
//! TODO(dzh)
|
|
|
|
|
std::unordered_map<std::string, int> dup_output;
|
|
|
|
|
std::unordered_map<std::string std::vector<int>> dup_output_ops;
|
|
|
|
|
const unsigned uniq_id_local = uniq_id;
|
|
|
|
|
unsigned op_id_offset = 0;
|
|
|
|
|
for (auto& fwd : forwardOp) {
|
|
|
|
|
auto bwd = Backward(fwd, no_grad_names);
|
|
|
|
|
net->AddOp(bwd);
|
|
|
|
|
for (size_t i = 0; i < bwd.outputs_; ++i) {
|
|
|
|
|
bwd->outputs_[i] += OperatorBase::EMPTY_VAR_NAME();
|
|
|
|
|
if (dup_output.find(bwd->inputs_[i]) == dup_output.end()) {
|
|
|
|
|
dup_output[bwd->inputs_[i]] = 1;
|
|
|
|
|
dup_output_ops[bwd->inputs_[i]] = std::vector<int>{op_id_offset++};
|
|
|
|
|
} else {
|
|
|
|
|
dup_output[bwd->inputs_[i]]++;
|
|
|
|
|
dup_output_ops[bwd->inputs_[i]].emplace_back(op_id_offset++);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
for (auto dup : dup_output) {
|
|
|
|
|
if (dup.second == 1) continue;
|
|
|
|
|
auto op_ids = dup_output_ops.at(dup.first);
|
|
|
|
|
for (auto& op_id : op_ids) {
|
|
|
|
|
auto& op_ptr = net->ops_[op_id];
|
|
|
|
|
for (size_t i = 0; i < op_ptr->inputs_.size(); ++i) {
|
|
|
|
|
if (op_ptr->inputs_[i] == dup.first) {
|
|
|
|
|
// unique the duplicate name
|
|
|
|
|
op_ptr->inputs_[i] += std::to_string(uniq_id++);
|
|
|
|
|
// TODO(dzh): need a generic add op here
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
//! TODO(fjy)
|
|
|
|
|
std::shared_ptr<OperatorBase> grad_op = OpRegistry::CreateGradOp(forwardOp);
|
|
|
|
|