|
|
|
@ -198,6 +198,7 @@ Add a mark to which output is temporary is helpful for future optimization.
|
|
|
|
|
|
|
|
|
|
class OpRegistry {
|
|
|
|
|
using OpCreator = std::function<OperatorBase*()>;
|
|
|
|
|
using VarIndexMap = std::unordered_map<std::string, int>;
|
|
|
|
|
|
|
|
|
|
public:
|
|
|
|
|
template <typename OpType, typename ProtoMakerType>
|
|
|
|
@ -212,6 +213,17 @@ class OpRegistry {
|
|
|
|
|
op_proto.IsInitialized(),
|
|
|
|
|
"Fail to initialize %s's OpProto, because %s is not initialized",
|
|
|
|
|
op_type, op_proto.InitializationErrorString());
|
|
|
|
|
|
|
|
|
|
VarIndexMaps()[op_type].reset(new VarIndexMap());
|
|
|
|
|
auto& varmap = *VarIndexMaps()[op_type];
|
|
|
|
|
int idx = 0;
|
|
|
|
|
for (auto& var : op_proto.inputs()) {
|
|
|
|
|
varmap[var.name()] = idx++;
|
|
|
|
|
}
|
|
|
|
|
idx = 0;
|
|
|
|
|
for (auto& var : op_proto.outputs()) {
|
|
|
|
|
varmap[var.name()] = idx++;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static OperatorPtr CreateOp(const OpDesc& op_desc) {
|
|
|
|
@ -220,7 +232,6 @@ class OpRegistry {
|
|
|
|
|
OperatorPtr op(creators().at(op_type)());
|
|
|
|
|
//! Fill op's data member. Not use constructor because it will be noising
|
|
|
|
|
//! for Op developer.
|
|
|
|
|
const OpProto& op_proto = protos().at(op_type);
|
|
|
|
|
op->type_ = op_desc.type();
|
|
|
|
|
// set op's inputs_ from desc.
|
|
|
|
|
op->inputs_.reserve((size_t)op_desc.inputs_size());
|
|
|
|
@ -240,25 +251,31 @@ class OpRegistry {
|
|
|
|
|
//! Convert Temporary variable name to an unique variable name.
|
|
|
|
|
GenerateTempVariableName(op.get());
|
|
|
|
|
|
|
|
|
|
// set argument offsets stored in op.
|
|
|
|
|
CreateInOutOffsetMap(op, op_proto);
|
|
|
|
|
//! set argument offsets stored in op.
|
|
|
|
|
{
|
|
|
|
|
auto var_index_it = VarIndexMaps().find(op_type);
|
|
|
|
|
if (var_index_it != VarIndexMaps().end()) {
|
|
|
|
|
op->in_out_idxs_ = var_index_it->second;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
//! Other op's custom Init for a complex Op. For simple Op, the Init
|
|
|
|
|
//! method do nothing.
|
|
|
|
|
op->Init();
|
|
|
|
|
return op;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// init op.in_out_idxs_ to accelerate argument's offset lookup.
|
|
|
|
|
static void CreateInOutOffsetMap(OperatorPtr op, const OpProto& proto) {
|
|
|
|
|
op->CreateInOutOffsetMap(proto);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static std::unordered_map<std::string, OpProto>& protos() {
|
|
|
|
|
static std::unordered_map<std::string, OpProto> protos_;
|
|
|
|
|
return protos_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
static std::unordered_map<std::string, std::shared_ptr<VarIndexMap>>&
|
|
|
|
|
VarIndexMaps() {
|
|
|
|
|
static std::unordered_map<std::string, std::shared_ptr<VarIndexMap>> maps_;
|
|
|
|
|
return maps_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static void GenerateTempVariableName(OperatorBase* op) {
|
|
|
|
|
static std::atomic<size_t> gUniqId(0UL);
|
|
|
|
|
for (auto& outname : op->outputs_) {
|
|
|
|
|