|
|
|
@ -162,11 +162,8 @@ class OpRegistry {
|
|
|
|
|
auto op_create_it = op_creators().find(type);
|
|
|
|
|
PADDLE_ENFORCE(op_create_it != op_creators().end(),
|
|
|
|
|
"Operator %s cannot be found.", type);
|
|
|
|
|
|
|
|
|
|
auto attrMap = attrs;
|
|
|
|
|
op_checkers().at(type).Check(attrMap);
|
|
|
|
|
auto op = op_create_it->second(type, inputs, outputs, attrMap);
|
|
|
|
|
GenerateTempVariableName(op);
|
|
|
|
|
op_checkers().at(type).Check(attrs);
|
|
|
|
|
auto op = op_create_it->second(type, inputs, outputs, attrs);
|
|
|
|
|
|
|
|
|
|
return std::shared_ptr<OperatorBase>(op);
|
|
|
|
|
}
|
|
|
|
@ -217,21 +214,6 @@ class OpRegistry {
|
|
|
|
|
static std::unordered_map<std::string, OpAttrChecker> op_checkers_;
|
|
|
|
|
return op_checkers_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static void GenerateTempVariableName(OperatorBase* op) {
|
|
|
|
|
static std::atomic<size_t> gUniqId(0UL);
|
|
|
|
|
for (auto& output : op->Outputs()) {
|
|
|
|
|
for (auto& output_name : output.second) {
|
|
|
|
|
if (output_name == kTempVarName) {
|
|
|
|
|
auto new_name = output_name;
|
|
|
|
|
new_name += op->Type();
|
|
|
|
|
new_name += "@";
|
|
|
|
|
new_name += std::to_string(gUniqId.fetch_add(1));
|
|
|
|
|
op->Rename(output_name, new_name);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class Registrar {
|
|
|
|
|