|
|
|
@ -199,6 +199,7 @@ Add a mark to which output is temporary is helpful for future optimization.
|
|
|
|
|
class OpRegistry {
|
|
|
|
|
using OpCreator = std::function<OperatorBase*()>;
|
|
|
|
|
using VarIndexMap = std::unordered_map<std::string, int>;
|
|
|
|
|
using VarNameList = std::vector<std::string>;
|
|
|
|
|
|
|
|
|
|
public:
|
|
|
|
|
template <typename OpType, typename ProtoMakerType>
|
|
|
|
@ -226,42 +227,51 @@ class OpRegistry {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static OperatorPtr CreateOp(const OpDesc& op_desc) {
|
|
|
|
|
//! Create a OpPtr by type.
|
|
|
|
|
std::string op_type = op_desc.type();
|
|
|
|
|
OperatorPtr op(creators().at(op_type)());
|
|
|
|
|
//! Fill op's data member. Not use constructor because it will be noising
|
|
|
|
|
//! for Op developer.
|
|
|
|
|
op->type_ = op_desc.type();
|
|
|
|
|
// set op's inputs_ from desc.
|
|
|
|
|
op->inputs_.reserve((size_t)op_desc.inputs_size());
|
|
|
|
|
std::copy(op_desc.inputs().begin(), op_desc.inputs().end(),
|
|
|
|
|
std::back_inserter(op->inputs_));
|
|
|
|
|
// set op's outputs_ from desc.
|
|
|
|
|
op->outputs_.reserve((size_t)op_desc.outputs_size());
|
|
|
|
|
std::copy(op_desc.outputs().begin(), op_desc.outputs().end(),
|
|
|
|
|
std::back_inserter(op->outputs_));
|
|
|
|
|
static OperatorPtr CreateOp(const std::string& type,
|
|
|
|
|
const VarNameList& inputs,
|
|
|
|
|
const VarNameList& outputs,
|
|
|
|
|
const AttributeMap& attrs) {
|
|
|
|
|
auto op_create_it = creators().find(type);
|
|
|
|
|
PADDLE_ENFORCE(op_create_it != creators().end(),
|
|
|
|
|
"Operator %s cannot be found", type);
|
|
|
|
|
|
|
|
|
|
//! Fill attrs, and validate attrs.
|
|
|
|
|
for (auto& attr : op_desc.attrs()) {
|
|
|
|
|
op->attrs_[attr.name()] = AttrTypeHelper::GetAttrValue(attr);
|
|
|
|
|
}
|
|
|
|
|
op_checkers().at(op_type).Check(op->attrs_);
|
|
|
|
|
auto op = op_create_it->second();
|
|
|
|
|
op->type_ = type;
|
|
|
|
|
op->inputs_ = inputs;
|
|
|
|
|
op->outputs_ = outputs;
|
|
|
|
|
op->attrs_ = attrs;
|
|
|
|
|
op_checkers().at(type).Check(op->attrs_);
|
|
|
|
|
|
|
|
|
|
//! Convert Temporary variable name to an unique variable name.
|
|
|
|
|
GenerateTempVariableName(op.get());
|
|
|
|
|
GenerateTempVariableName(op);
|
|
|
|
|
|
|
|
|
|
//! set argument offsets stored in op.
|
|
|
|
|
{
|
|
|
|
|
auto var_index_it = VarIndexMaps().find(op_type);
|
|
|
|
|
auto var_index_it = VarIndexMaps().find(type);
|
|
|
|
|
if (var_index_it != VarIndexMaps().end()) {
|
|
|
|
|
op->in_out_idxs_ = var_index_it->second;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
//! Other op's custom Init for a complex Op. For simple Op, the Init
|
|
|
|
|
//! method do nothing.
|
|
|
|
|
|
|
|
|
|
op->Init();
|
|
|
|
|
return op;
|
|
|
|
|
return OperatorPtr(op);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static OperatorPtr CreateOp(const OpDesc& op_desc) {
|
|
|
|
|
std::vector<std::string> inputs;
|
|
|
|
|
inputs.reserve((size_t)op_desc.inputs_size());
|
|
|
|
|
std::copy(op_desc.inputs().begin(), op_desc.inputs().end(),
|
|
|
|
|
std::back_inserter(inputs));
|
|
|
|
|
|
|
|
|
|
std::vector<std::string> outputs;
|
|
|
|
|
outputs.reserve((size_t)op_desc.outputs_size());
|
|
|
|
|
std::copy(op_desc.outputs().begin(), op_desc.outputs().end(),
|
|
|
|
|
std::back_inserter(outputs));
|
|
|
|
|
|
|
|
|
|
AttributeMap attrs;
|
|
|
|
|
for (auto& attr : op_desc.attrs()) {
|
|
|
|
|
attrs[attr.name()] = AttrTypeHelper::GetAttrValue(attr);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return CreateOp(op_desc.type(), inputs, outputs, attrs);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static std::unordered_map<std::string, OpProto>& protos() {
|
|
|
|
|