|
|
|
@ -49,11 +49,9 @@ static std::shared_ptr<OperatorBase> EmptyOp() {
|
|
|
|
|
return net_op;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static void DeDuplicate(NetOp* net, std::unordered_se)
|
|
|
|
|
|
|
|
|
|
static std::shared_ptr<OperatorBase> BackwardImpl(
|
|
|
|
|
const OperatorBase& forwardOp,
|
|
|
|
|
std::unordered_set<std::string>& no_grad_names, unsigned& uniq_id) {
|
|
|
|
|
static std::shared_ptr<OperatorBase> BackwardImpl(
|
|
|
|
|
const OperatorBase& forwardOp,
|
|
|
|
|
std::unordered_set<std::string>& no_grad_names, size_t& uniq_id) {
|
|
|
|
|
if (AllInSet(forwardOp.inputs_, OperatorBase::GRAD_VAR_SUFFIX(),
|
|
|
|
|
no_grad_names)) {
|
|
|
|
|
return EmptyOp();
|
|
|
|
@ -73,13 +71,16 @@ static void DeDuplicate(NetOp* net, std::unordered_se)
|
|
|
|
|
if (forwardOp.IsNetOp()) {
|
|
|
|
|
//! TODO(dzh)
|
|
|
|
|
std::unordered_map<std::string, int> dup_output;
|
|
|
|
|
std::unordered_map<std::string std::vector<int>> dup_output_ops;
|
|
|
|
|
const unsigned uniq_id_local = uniq_id;
|
|
|
|
|
unsigned op_id_offset = 0;
|
|
|
|
|
for (auto& fwd : forwardOp) {
|
|
|
|
|
auto bwd = Backward(fwd, no_grad_names);
|
|
|
|
|
std::unordered_map<std::string, std::vector<int>> dup_output_ops;
|
|
|
|
|
// const unsigned uniq_id_local = uniq_id;
|
|
|
|
|
int op_id_offset = 0;
|
|
|
|
|
// Because it is a net op, it can static_cast.
|
|
|
|
|
auto& forwardNet = static_cast<const NetOp&>(forwardOp);
|
|
|
|
|
|
|
|
|
|
for (auto& fwd : forwardNet.ops_) {
|
|
|
|
|
auto bwd = Backward(*fwd, no_grad_names);
|
|
|
|
|
net->AddOp(bwd);
|
|
|
|
|
for (size_t i = 0; i < bwd.outputs_; ++i) {
|
|
|
|
|
for (size_t i = 0; i < bwd->outputs_.size(); ++i) {
|
|
|
|
|
bwd->outputs_[i] += OperatorBase::EMPTY_VAR_NAME();
|
|
|
|
|
if (dup_output.find(bwd->inputs_[i]) == dup_output.end()) {
|
|
|
|
|
dup_output[bwd->inputs_[i]] = 1;
|
|
|
|
@ -138,7 +139,7 @@ extern std::shared_ptr<OperatorBase> Backward(
|
|
|
|
|
for (auto& name : no_grad_vars) {
|
|
|
|
|
no_grad_names.insert(name + OperatorBase::GRAD_VAR_SUFFIX());
|
|
|
|
|
}
|
|
|
|
|
int uid = 0;
|
|
|
|
|
size_t uid = 0;
|
|
|
|
|
return BackwardImpl(forwardOp, no_grad_names, uid);
|
|
|
|
|
}
|
|
|
|
|
} // namespace framework
|
|
|
|
|