revert-3824-remove_grad_op_type
qiaolongfei 8 years ago
parent a240bce152
commit 252d41655a

@ -162,11 +162,8 @@ class OpRegistry {
auto op_create_it = op_creators().find(type); auto op_create_it = op_creators().find(type);
PADDLE_ENFORCE(op_create_it != op_creators().end(), PADDLE_ENFORCE(op_create_it != op_creators().end(),
"Operator %s cannot be found.", type); "Operator %s cannot be found.", type);
op_checkers().at(type).Check(attrs);
auto attrMap = attrs; auto op = op_create_it->second(type, inputs, outputs, attrs);
op_checkers().at(type).Check(attrMap);
auto op = op_create_it->second(type, inputs, outputs, attrMap);
GenerateTempVariableName(op);
return std::shared_ptr<OperatorBase>(op); return std::shared_ptr<OperatorBase>(op);
} }
@ -217,21 +214,6 @@ class OpRegistry {
static std::unordered_map<std::string, OpAttrChecker> op_checkers_; static std::unordered_map<std::string, OpAttrChecker> op_checkers_;
return op_checkers_; return op_checkers_;
} }
static void GenerateTempVariableName(OperatorBase* op) {
static std::atomic<size_t> gUniqId(0UL);
for (auto& output : op->Outputs()) {
for (auto& output_name : output.second) {
if (output_name == kTempVarName) {
auto new_name = output_name;
new_name += op->Type();
new_name += "@";
new_name += std::to_string(gUniqId.fetch_add(1));
op->Rename(output_name, new_name);
}
}
}
}
}; };
class Registrar { class Registrar {

Loading…
Cancel
Save