Remove input_format in backward.cc

revert-3824-remove_grad_op_type
Yu Yang 8 years ago
parent ffbb0be21f
commit 186fb0c118

@ -127,11 +127,8 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
net->ops_[op_offset]->Rename(name, dup_outputs.back()); net->ops_[op_offset]->Rename(name, dup_outputs.back());
} }
insert_position.push_back( insert_position.push_back(
{dup_op.back(), {dup_op.back(), OpRegistry::CreateOp("add", {{"X", {dup_outputs}}},
OpRegistry::CreateOp( {{"Out", {name}}}, {})});
"add", {{"X", {dup_outputs}}}, {{"Out", {name}}},
{{"input_format",
std::vector<int>{0, static_cast<int>(dup_outputs.size())}}})});
} }
insert_position.sort( insert_position.sort(
@ -140,7 +137,6 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
for (auto& pos : insert_position) { for (auto& pos : insert_position) {
net->InsertOp(pos.first + 1, pos.second); net->InsertOp(pos.first + 1, pos.second);
} }
} else { } else {
std::shared_ptr<OperatorBase> grad_op = OpRegistry::CreateGradOp(forwardOp); std::shared_ptr<OperatorBase> grad_op = OpRegistry::CreateGradOp(forwardOp);
@ -176,7 +172,7 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
net->type_ = "@GENERATED_BACKWARD@"; net->type_ = "@GENERATED_BACKWARD@";
net->CompleteAddOp(); net->CompleteAddOp();
return net; return net;
} } // namespace framework
// See header for comments // See header for comments
std::shared_ptr<OperatorBase> Backward( std::shared_ptr<OperatorBase> Backward(

Loading…
Cancel
Save