/** * Copyright 2019-2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef INC_EXTERNAL_GRAPH_OPERATOR_H_ #define INC_EXTERNAL_GRAPH_OPERATOR_H_ #include #include #include #include #include #include "./ge_error_codes.h" #include "./inference_context.h" #include "./tensor.h" #ifndef USER_GE_LOGI #define USER_GE_LOGI(...) #endif // USER_GE_LOGI #ifndef USER_GE_LOGW #define USER_GE_LOGW(...) #endif // USER_GE_LOGW #ifndef USER_GE_LOGE #define USER_GE_LOGE(...) #endif // USER_GE_LOGE #define DYNAMIC_OUTPUT_TD_NUM(name) ("__dynamic_output_" + name + "_cnt") #define DYNAMIC_INPUT_TD_NUM(name) ("__dynamic_input_" + name + "_cnt") namespace ge { class Operator; class OperatorImpl; class NamedAttrs; class Graph; class AttrValue; using SubgraphBuilder = std::function; using OperatorImplPtr = std::shared_ptr; using OperatorPtr = std::shared_ptr; class OpIO; using OutHandler = std::shared_ptr; using InHandler = std::shared_ptr; using std::function; using std::shared_ptr; using std::string; class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { public: friend class OperatorImpl; friend class GraphBuilderImpl; using OpInt = int64_t; using OpFloat = float; using OpString = string; using OpBool = bool; using OpTensor = Tensor; using OpType = ge::DataType; using OpNamedAttrs = ge::NamedAttrs; using OpListInt = std::vector; using OpListFloat = std::vector; using OpListString = std::vector; using OpListBool = std::vector; using OpListTensor = std::vector; using OpBytes = std::vector; using OpListListInt = std::vector>; using OpListType = std::vector; using OpListNamedAttrs = std::vector; Operator() {} explicit Operator(const string &type); Operator(const string &name, const string &type); virtual ~Operator() = default; bool IsEmpty() const; string GetName() const; string GetOpType() const; // Only has one output index = 0 Operator &SetInput(const string &dst_name, const Operator &src_oprt); Operator &SetInput(const string &dst_name, const Operator &src_oprt, const string &name); Operator &AddControlInput(const Operator &src_oprt); graphStatus GetInputConstData(const string &dst_name, Tensor &data) const; TensorDesc GetInputDesc(const string &name) const; TensorDesc GetInputDesc(uint32_t index) const; int GetDynamicOutputNum(const string &name) const; int GetDynamicInputNum(const string &name) const; graphStatus TryGetInputDesc(const string &name, TensorDesc &tensor_desc) const; graphStatus UpdateInputDesc(const string &name, const TensorDesc &tensor_desc); TensorDesc GetOutputDesc(const string &name) const; TensorDesc GetOutputDesc(uint32_t index) const; graphStatus UpdateOutputDesc(const string &name, const TensorDesc &tensor_desc); TensorDesc GetDynamicInputDesc(const string &name, uint32_t index) const; graphStatus UpdateDynamicInputDesc(const string &name, uint32_t index, const TensorDesc &tensor_desc); TensorDesc GetDynamicOutputDesc(const string &name, uint32_t index) const; graphStatus UpdateDynamicOutputDesc(const string &name, uint32_t index, const TensorDesc &tensor_desc); graphStatus InferShapeAndType(); void SetInferenceContext(const InferenceContextPtr &inference_context); InferenceContextPtr GetInferenceContext() const; graphStatus VerifyAllAttr(bool disable_common_verifier = false); size_t GetInputsSize() const; size_t GetOutputsSize() const; const std::map GetAllAttrNamesAndTypes() const; Operator &SetAttr(const string &name, int64_t attr_value); Operator &SetAttr(const string &name, int32_t attr_value); Operator &SetAttr(const string &name, uint32_t attr_value); graphStatus GetAttr(const string &name, int64_t &attr_value) const; graphStatus GetAttr(const string &name, int32_t &attr_value) const; graphStatus GetAttr(const string &name, uint32_t &attr_value) const; Operator &SetAttr(const string &name, const std::vector &attr_value); Operator &SetAttr(const string &name, const std::vector &attr_value); Operator &SetAttr(const string &name, const std::vector &attr_value); Operator &SetAttr(const string &name, std::initializer_list &&attr_value); graphStatus GetAttr(const string &name, std::vector &attr_value) const; graphStatus GetAttr(const string &name, std::vector &attr_value) const; graphStatus GetAttr(const string &name, std::vector &attr_value) const; Operator &SetAttr(const string &name, float attr_value); graphStatus GetAttr(const string &name, float &attr_value) const; Operator &SetAttr(const string &name, const std::vector &attr_value); graphStatus GetAttr(const string &name, std::vector &attr_value) const; Operator &SetAttr(const string &name, AttrValue &&attr_value); graphStatus GetAttr(const string &name, AttrValue &attr_value) const; Operator &SetAttr(const string &name, const string &attr_value); graphStatus GetAttr(const string &name, string &attr_value) const; Operator &SetAttr(const string &name, const std::vector &attr_value); graphStatus GetAttr(const string &name, std::vector &attr_value) const; Operator &SetAttr(const string &name, bool attr_value); graphStatus GetAttr(const string &name, bool &attr_value) const; Operator &SetAttr(const string &name, const std::vector &attr_value); graphStatus GetAttr(const string &name, std::vector &attr_value) const; Operator &SetAttr(const string &name, const Tensor &attr_value); graphStatus GetAttr(const string &name, Tensor &attr_value) const; Operator &SetAttr(const string &name, const std::vector &attr_value); graphStatus GetAttr(const string &name, std::vector &attr_value) const; // Bytes type Operator &SetAttr(const string &name, const OpBytes &attr_value); // Bytes type graphStatus GetAttr(const string &name, OpBytes &attr_value) const; Operator &SetAttr(const string &name, const std::vector> &attr_value); graphStatus GetAttr(const string &name, std::vector> &attr_value) const; Operator &SetAttr(const string &name, const std::vector &attr_value); graphStatus GetAttr(const string &name, std::vector &attr_value) const; Operator &SetAttr(const string &name, const ge::DataType &attr_value); graphStatus GetAttr(const string &name, ge::DataType &attr_value) const; // func type Operator &SetAttr(const string &name, const ge::NamedAttrs &attr_value); graphStatus GetAttr(const string &name, ge::NamedAttrs &attr_value) const; Operator &SetAttr(const string &name, const std::vector &attr_value); graphStatus GetAttr(const string &name, std::vector &attr_value) const; void BreakConnect() const; size_t GetSubgraphNamesCount() const; std::vector GetSubgraphNames() const; SubgraphBuilder GetSubgraphBuilder(const string &name) const; Graph GetSubgraph(const string &name) const; SubgraphBuilder GetDynamicSubgraphBuilder(const string &name, uint32_t index) const; Graph GetDynamicSubgraph(const string &name, uint32_t index) const; protected: void AttrRegister(const string &name, float attr_value); void AttrRegister(const string &name, const std::vector &attr_value); void AttrRegister(const string &name, int64_t attr_value); void AttrRegister(const string &name, const std::vector &attr_value); void AttrRegister(const string &name, const string &attr_value); void AttrRegister(const string &name, const std::vector &attr_value); void AttrRegister(const string &name, bool attr_value); void AttrRegister(const string &name, const std::vector &attr_value); void AttrRegister(const string &name, const Tensor &attr_value); void AttrRegister(const string &name, const std::vector &attr_value); void AttrRegister(const string &name, const OpBytes &attr_value); void AttrRegister(const string &name, const std::vector> &attr_value); void AttrRegister(const string &name, const std::vector &attr_value); void AttrRegister(const string &name, const ge::DataType &attr_value); void AttrRegister(const string &name, const ge::NamedAttrs &attr_value); void AttrRegister(const string &name, const std::vector &attr_value); explicit Operator(OperatorImplPtr &&op_impl); void InputRegister(const string &name); void OptionalInputRegister(const string &name); void InferFuncRegister(const std::function &func); void VerifierFuncRegister(const std::function &func); void InferFormatFuncRegister(const std::function &func); void OutputRegister(const string &name); void DynamicInputRegister(const string &name, const unsigned int num, bool is_push_back = true); void DynamicInputRegisterByIndex(const string &name, const unsigned int num, size_t index); void DynamicOutputRegister(const string &name, const unsigned int num, bool is_push_back = true); void RequiredAttrRegister(const string &name); graphStatus VerifyAll(); // Only has one output index = 0 Operator &SetInput(const string &dst_name, uint32_t dst_index, const Operator &src_oprt); Operator &SetInput(const string &dst_name, uint32_t dst_index, const Operator &src_oprt, const string &name); void SubgraphRegister(const string &ir_name, bool dynamic); void SubgraphCountRegister(const string &ir_name, uint32_t count); void SetSubgraphBuilder(const string &ir_name, uint32_t index, const SubgraphBuilder &builder); private: Operator &SetInput(const string &dst_name, const OutHandler &out_handler); OutHandler GetOutput(const string &name) const; OperatorImplPtr GetOperatorImplPtr() const; OperatorImplPtr operator_impl_{nullptr}; graphStatus GetInputConstDataOut(const string &dst_name, Tensor &data) const; }; } // namespace ge #endif // INC_EXTERNAL_GRAPH_OPERATOR_H_