|
|
|
@ -120,8 +120,10 @@ class OpProtoAndCheckerMaker {
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class OpRegistry {
|
|
|
|
|
using OpCreator = std::function<OperatorBase*()>;
|
|
|
|
|
using VarNameMap = OperatorBase::VarNameMap;
|
|
|
|
|
using OpCreator = std::function<OperatorBase*(
|
|
|
|
|
const std::string& type, const VarNameMap& inputs,
|
|
|
|
|
const VarNameMap& outputs, const AttributeMap& attrs)>;
|
|
|
|
|
|
|
|
|
|
public:
|
|
|
|
|
template <typename OpType, typename ProtoMakerType>
|
|
|
|
@ -153,14 +155,9 @@ class OpRegistry {
|
|
|
|
|
PADDLE_ENFORCE(op_create_it != op_creators().end(),
|
|
|
|
|
"Operator %s cannot be found.", type);
|
|
|
|
|
|
|
|
|
|
auto op = op_create_it->second();
|
|
|
|
|
op->type_ = type;
|
|
|
|
|
op->inputs_ = inputs;
|
|
|
|
|
op->outputs_ = outputs;
|
|
|
|
|
|
|
|
|
|
op->attrs_ = attrs;
|
|
|
|
|
op_checkers().at(type).Check(op->attrs_);
|
|
|
|
|
|
|
|
|
|
auto attrMap = attrs;
|
|
|
|
|
op_checkers().at(type).Check(attrMap);
|
|
|
|
|
auto op = op_create_it->second(type, inputs, outputs, attrMap);
|
|
|
|
|
GenerateTempVariableName(op);
|
|
|
|
|
|
|
|
|
|
op->Init();
|
|
|
|
@ -217,12 +214,14 @@ class OpRegistry {
|
|
|
|
|
|
|
|
|
|
static void GenerateTempVariableName(OperatorBase* op) {
|
|
|
|
|
static std::atomic<size_t> gUniqId(0UL);
|
|
|
|
|
for (auto& output : op->outputs_) {
|
|
|
|
|
for (auto& output : op->Outputs()) {
|
|
|
|
|
for (auto& output_name : output.second) {
|
|
|
|
|
if (output_name == kTempVarName) {
|
|
|
|
|
output_name += op->type_;
|
|
|
|
|
output_name += "@";
|
|
|
|
|
output_name += std::to_string(gUniqId.fetch_add(1));
|
|
|
|
|
auto new_name = output_name;
|
|
|
|
|
new_name += op->Type();
|
|
|
|
|
new_name += "@";
|
|
|
|
|
new_name += std::to_string(gUniqId.fetch_add(1));
|
|
|
|
|
op->Rename(output_name, new_name);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|