You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
130 lines
4.5 KiB
130 lines
4.5 KiB
/**
|
|
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#ifndef INC_EXTERNAL_REGISTER_REGISTER_H_
|
|
#define INC_EXTERNAL_REGISTER_REGISTER_H_
|
|
|
|
#include <functional>
|
|
#include <initializer_list>
|
|
#include <map>
|
|
#include <memory>
|
|
#include <set>
|
|
#include <string>
|
|
#include <utility>
|
|
#include <unordered_map>
|
|
#include <vector>
|
|
|
|
#include "graph/operator.h"
|
|
#include "register/register_error_codes.h"
|
|
#include "register/register_fmk_types.h"
|
|
#include "register/register_types.h"
|
|
|
|
using std::make_shared;
|
|
using std::map;
|
|
using std::pair;
|
|
using std::string;
|
|
using std::to_string;
|
|
using std::unique_ptr;
|
|
using std::vector;
|
|
|
|
// Forward declarations of GE classes so this header can name them (e.g. in
// reference parameters and friend declarations) without their definitions.
namespace ge {
class TBEPluginManager;
class Tensor;
class TensorDesc;
class Operator;
}  // namespace ge
|
|
|
|
// Forward declaration of the protobuf message base class; declarations in
// this header only take it by pointer, and no protobuf header is included.
namespace google {
namespace protobuf {
class Message;
}  // namespace protobuf
}  // namespace google
|
|
|
|
namespace domi {
|
|
Status AutoMappingFn(const google::protobuf::Message *op_src, ge::Operator &op);
|
|
Status AutoMappingFnDynamic(const google::protobuf::Message *op_src, ge::Operator &op,
|
|
std::map<std::string, std::pair<std::string, std::string>> dynamic_name_attr_value,
|
|
int in_pos = -1, int out_pos = -1);
|
|
Status AutoMappingSubgraphIndex(const ge::Graph &graph, const std::function<int(int data_index)> &input,
|
|
const std::function<int(int netoutput_index)> &output);
|
|
Status AutoMappingSubgraphIndex(const ge::Graph &graph,
|
|
const std::function<Status(int data_index, int &parent_input_index)> &input,
|
|
const std::function<Status(int netoutput_index, int &parent_output_index)> &output);
|
|
using google::protobuf::Message;
|
|
class OpRegistrationDataImpl;
|
|
|
|
using ParseParamFunc = std::function<domi::Status(const google::protobuf::Message *, ge::Operator &)>;
|
|
using FusionParseParamFunc =
|
|
std::function<domi::Status(const std::vector<const google::protobuf::Message *>, ge::Operator &)>;
|
|
using ParseSubgraphFunc = std::function<Status(const std::string &subgraph_name, const ge::Graph &graph)>;
|
|
|
|
class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistrationData {
|
|
public:
|
|
OpRegistrationData(const std::string &om_optype);
|
|
|
|
~OpRegistrationData();
|
|
|
|
OpRegistrationData &FrameworkType(const domi::FrameworkType &fmk_type);
|
|
|
|
OpRegistrationData &OriginOpType(const std::initializer_list<std::string> &ori_optype_list);
|
|
|
|
OpRegistrationData &OriginOpType(const std::string &ori_optype);
|
|
|
|
OpRegistrationData &ParseParamsFn(const ParseParamFunc &parseParamFn);
|
|
|
|
OpRegistrationData &FusionParseParamsFn(const FusionParseParamFunc &fusionParseParamFn);
|
|
|
|
OpRegistrationData &ParseSubgraphPostFn(const ParseSubgraphFunc &subgraph_post_fn);
|
|
|
|
OpRegistrationData &ImplyType(const domi::ImplyType &imply_type);
|
|
|
|
OpRegistrationData &DelInputWithCond(int inputIdx, const std::string &attrName, bool attrValue);
|
|
|
|
OpRegistrationData &DelInputWithOriginalType(int input_idx, const std::string &ori_type);
|
|
|
|
domi::ImplyType GetImplyType() const;
|
|
std::string GetOmOptype() const;
|
|
std::set<std::string> GetOriginOpTypeSet() const;
|
|
domi::FrameworkType GetFrameworkType() const;
|
|
ParseParamFunc GetParseParamFn() const;
|
|
FusionParseParamFunc GetFusionParseParamFn() const;
|
|
ParseSubgraphFunc GetParseSubgraphPostFn() const;
|
|
|
|
private:
|
|
std::shared_ptr<OpRegistrationDataImpl> impl_;
|
|
friend class OpRegistry;
|
|
friend class OpRegistrationTbe;
|
|
friend class ge::TBEPluginManager;
|
|
};
|
|
|
|
class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpReceiver {
|
|
public:
|
|
OpReceiver(OpRegistrationData ®_data);
|
|
~OpReceiver() {}
|
|
};
|
|
|
|
#define REGISTER_CUSTOM_OP(name) REGISTER_CUSTOM_OP_UNIQ_HELPER(__COUNTER__, name)
|
|
#define REGISTER_CUSTOM_OP_UNIQ_HELPER(ctr, name) REGISTER_CUSTOM_OP_UNIQ(ctr, name)
|
|
#define REGISTER_CUSTOM_OP_UNIQ(ctr, name) \
|
|
static OpReceiver register_op##ctr __attribute__((unused)) = OpRegistrationData(name)
|
|
} // namespace domi
|
|
|
|
namespace ge {
|
|
using OpRegistrationData = domi::OpRegistrationData;
|
|
using OpReceiver = domi::OpReceiver;
|
|
} // namespace ge
|
|
#endif // INC_EXTERNAL_REGISTER_REGISTER_H_
|