!38 Synchronize with latest Ascend software suite 17 Jun 2020

Merge pull request !38 from yanghaoran/master
pull/38/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 1350673d51

@ -42,7 +42,7 @@ class BlockingQueue {
return false;
}
item = queue_.front();
item = std::move(queue_.front());
queue_.pop_front();
full_cond_.notify_one();
@ -71,6 +71,27 @@ class BlockingQueue {
return true;
}
// Moves `item` into the queue. When the queue is at capacity, either blocks
// until space is available (is_wait == true) or fails immediately
// (is_wait == false). Returns false if the queue has been stopped.
bool Push(T &&item, bool is_wait = true) {
  std::unique_lock<std::mutex> lock(mutex_);
  for (;;) {
    if (is_stoped_) {
      return false;  // queue was shut down: reject the element
    }
    if (queue_.size() < max_size_) {
      break;  // room available, go push
    }
    if (!is_wait) {
      return false;  // caller requested non-blocking behaviour
    }
    full_cond_.wait(lock);  // wait for a consumer to free a slot
  }
  queue_.emplace_back(std::move(item));
  empty_cond_.notify_one();  // wake one waiting consumer
  return true;
}
void Stop() {
{
std::unique_lock<std::mutex> lock(mutex_);

@ -26,6 +26,7 @@ using std::string;
namespace ge {
// GETaskKernelHcclInfo is expected to be eliminated eventually, so DAVINCI_TRAIN/DAVINCI_CLOUD is not needed
struct GETaskKernelHcclInfo {
string input_name;
string hccl_type;
void *inputDataAddr;
void *outputDataAddr;
@ -35,6 +36,7 @@ struct GETaskKernelHcclInfo {
int32_t opType;
int64_t rootId;
uint64_t workSpaceMemSize;
std::vector<int64_t> dims;
std::vector<rtStream_t> hcclStreamList;
};
@ -48,7 +50,7 @@ struct GETaskInfo {
uint32_t privateDefLen;
void *opsKernelStorePtr;
GETaskKernelHcclInfo kernelHcclInfo;
std::vector<GETaskKernelHcclInfo> kernelHcclInfo;
};
} // namespace ge
#endif // INC_COMMON_OPSKERNEL_GE_TASK_INFO_H_

@ -73,7 +73,7 @@ class OpsKernelInfoStore {
// only call fe engine interface to compile single op
virtual Status CompileOp(vector<ge::NodePtr> &node_vec) { return SUCCESS; }
virtual Status CompileOpRun(vector<ge::NodePtr> &node_vec) { return SUCCESS; }
// load task for op
virtual Status LoadTask(GETaskInfo &task) { return SUCCESS; }

@ -33,6 +33,7 @@ const char *const OPTION_EXEC_SESSION_ID = "ge.exec.sessionId";
const char *const OPTION_EXEC_DEVICE_ID = "ge.exec.deviceId";
const char *const OPTION_EXEC_JOB_ID = "ge.exec.jobId";
const char *const OPTION_EXEC_IS_USEHCOM = "ge.exec.isUseHcom";
const char *const OPTION_EXEC_IS_USEHVD = "ge.exec.isUseHvd";
const char *const OPTION_EXEC_RANK_ID = "ge.exec.rankId";
const char *const OPTION_EXEC_POD_NAME = "ge.exec.podName";
const char *const OPTION_EXEC_DEPLOY_MODE = "ge.exec.deployMode";
@ -52,6 +53,7 @@ const char *const OPTION_EXEC_PROFILING_OPTIONS = "ge.exec.profilingOptions";
const char *const OPTION_EXEC_HCCL_FLAG = "ge.exec.hcclFlag";
const char *const OPTION_EXEC_ATOMIC_FLAG = "ge.exec.enable_atomic";
const char *const OPTION_EXEC_DISABLE_REUSED_MEMORY = "ge.exec.disableReuseMemory";
const char *const OPTION_EXEC_ENABLE_TAILING_OPTIMIZATION = "ge.exec.isTailingOptimization";
// Option key: memory init
const char *const GRAPH_MEMORY_MAX_SIZE = "ge.graphMemoryMaxSize";
@ -153,7 +155,7 @@ const std::string STREAM_MAX_PARALLEL_NUM = "ge.streamMaxParallelNum";
const std::string OUTPUT_DATATYPE = "ge.outputDatatype";
// configure opSelectImplmode to set the op select implementation mode
const std::string kOpSelectImplmode = "ge.opSelectImplmode";
const std::string OP_SELECT_IMPL_MODE = "ge.opSelectImplmode";
// configure whether to enable hcom parallel by session constructor options param,
// its value should be "0" or "1", default value is "0"
@ -214,6 +216,9 @@ const char *const ENABLE_PRINT_OP_PASS = "ge.enablePrintOpPass";
// Its value should be "true" or "false", default value is "false"
const char *const ENABLE_SINGLE_STREAM = "ge.enableSingleStream";
// Configure input fp16 nodes
const std::string INPUT_FP16_NODES = "ge.INPUT_NODES_SET_FP16";
// Graph run mode
enum GraphRunMode { PREDICTION = 0, TRAIN };
@ -263,14 +268,37 @@ static const char *const AUTO_TUNE_MODE = ge::AUTO_TUNE_MODE.c_str();
static const char *const CORE_TYPE = ge::CORE_TYPE.c_str();
static const char *const SOC_VERSION = ge::SOC_VERSION.c_str();
static const char *const ENABLE_SINGLE_STREAM = ge::ENABLE_SINGLE_STREAM;
static const char *const AICORE_NUM = ge::AICORE_NUM.c_str();
static const char *const FUSION_SWITCH_FILE = ge::FUSION_SWITCH_FILE.c_str();
static const char *const ENABLE_SMALL_CHANNEL = ge::ENABLE_SMALL_CHANNEL.c_str();
static const char *const QUANT_OPTIMIZE = ge::QUANT_OPTIMIZE.c_str();
static const char *const OP_SELECT_IMPL_MODE = ge::OP_SELECT_IMPL_MODE.c_str();
static const char *const OUTPUT_TYPE = ge::OUTPUT_DATATYPE.c_str();
static const char *const BUFFER_OPTIMIZE = ge::BUFFER_OPTIMIZE.c_str();
static const char *const ENABLE_COMPRESS_WEIGHT = ge::ENABLE_COMPRESS_WEIGHT.c_str();
static const char *const COMPRESS_WEIGHT_CONF = "compress_weight_conf";
static const char *const OUT_NODES = ge::OUTPUT_NODE_NAME.c_str();
static const char *const INPUT_FP16_NODES = ge::INPUT_FP16_NODES.c_str();
static const char *const LOG_LEVEL = "log";
// for interface: aclgrphBuildModel
const std::set<std::string> ir_builder_suppported_options = {INPUT_FORMAT, INPUT_SHAPE, DYNAMIC_BATCH_SIZE,
DYNAMIC_IMAGE_SIZE, INSERT_OP_FILE};
const std::set<std::string> ir_builder_suppported_options = {
INPUT_FORMAT, INPUT_SHAPE, DYNAMIC_BATCH_SIZE, DYNAMIC_IMAGE_SIZE,
INSERT_OP_FILE, OUTPUT_TYPE, BUFFER_OPTIMIZE, ENABLE_COMPRESS_WEIGHT,
COMPRESS_WEIGHT_CONF, OUT_NODES, INPUT_FP16_NODES, LOG_LEVEL};
// for interface: aclgrphBuildInitialize
const std::set<std::string> global_options = {
HEAD_STREAM, CORE_TYPE, SOC_VERSION, PRECISION_MODE, EXEC_DISABLE_REUSED_MEMORY,
AUTO_TUNE_MODE, ENABLE_SINGLE_STREAM};
const std::set<std::string> global_options = {HEAD_STREAM,
CORE_TYPE,
SOC_VERSION,
PRECISION_MODE,
EXEC_DISABLE_REUSED_MEMORY,
AUTO_TUNE_MODE,
ENABLE_SINGLE_STREAM,
AICORE_NUM,
FUSION_SWITCH_FILE,
ENABLE_SMALL_CHANNEL,
QUANT_OPTIMIZE,
OP_SELECT_IMPL_MODE};
} // namespace ir_option
} // namespace ge

@ -48,12 +48,9 @@ class NamedAttrs;
class Graph;
class AttrValue;
using SubgraphBuilder = std::function<Graph(const std::string &name)>;
using SubgraphBuilder = std::function<Graph()>;
using OperatorImplPtr = std::shared_ptr<OperatorImpl>;
class Graph;
using GraphBuilderCallback = std::function<Graph()>;
class OpIO;
using OutHandler = std::shared_ptr<OpIO>;
using InHandler = std::shared_ptr<OpIO>;
@ -139,12 +136,6 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator {
void SetInferenceContext(const InferenceContextPtr &inference_context);
InferenceContextPtr GetInferenceContext() const;
void SetGraphBuilder(const GraphBuilderCallback &builder);
graphStatus GetGraphBuilder(GraphBuilderCallback &builder) const;
void AddSubgraphName(const string &name);
string GetSubgraphName(int index) const;
graphStatus VerifyAllAttr(bool disable_common_verifier = false);
size_t GetInputsSize() const;
@ -265,9 +256,9 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator {
Operator &SetInput(const string &dst_name, uint32_t dst_index, const Operator &src_oprt, const string &name);
void SubgraphRegister(const std::string &name, bool dynamic);
void SubgraphCountRegister(const std::string &name, uint32_t count);
void SetSubgraphBuilder(const std::string &name, uint32_t index, const SubgraphBuilder &builder);
void SubgraphRegister(const string &ir_name, bool dynamic);
void SubgraphCountRegister(const string &ir_name, uint32_t count);
void SetSubgraphBuilder(const string &ir_name, uint32_t index, const SubgraphBuilder &builder);
private:
Operator &SetInput(const string &dst_name, const OutHandler &out_handler);

@ -186,56 +186,54 @@ class OpReg {
Operator::OutputRegister(#x); \
(void)OpReg()
#define DYNAMIC_INPUT(x, t) \
N(); \
__dy_input_##x(); \
} \
\
public: \
_THIS_TYPE &create_dynamic_input_##x(unsigned int num, bool isPushBack = true) { \
Operator::DynamicInputRegister(#x, num, isPushBack); \
return *this; \
} \
_THIS_TYPE &create_dynamic_input_byindex_##x(unsigned int num, size_t index) { \
Operator::DynamicInputRegisterByIndex(#x, num, index); \
return *this; \
} \
TensorDesc get_dynamic_input_desc_##x(unsigned int index) const { return Operator::GetDynamicInputDesc(#x, index); } \
graphStatus update_dynamic_input_desc_##x(unsigned int index, const TensorDesc &tensorDesc) { \
return Operator::UpdateDynamicInputDesc(#x, index, tensorDesc); \
} \
_THIS_TYPE &set_dynamic_input_##x(unsigned int dstIndex, Operator &v) { \
Operator::SetInput(#x, dstIndex, v); \
return *this; \
} \
_THIS_TYPE &set_dynamic_input_##x(unsigned int dstIndex, Operator &v, const string &srcName) { \
Operator::SetInput(#x, dstIndex, v, srcName); \
return *this; \
} \
\
private: \
void __dy_input_##x() { \
#define DYNAMIC_INPUT(x, t) \
N(); \
__dy_input_##x(); \
} \
\
public: \
_THIS_TYPE &create_dynamic_input_##x(uint32_t num, bool isPushBack = true) { \
Operator::DynamicInputRegister(#x, num, isPushBack); \
return *this; \
} \
_THIS_TYPE &create_dynamic_input_byindex_##x(uint32_t num, size_t index) { \
Operator::DynamicInputRegisterByIndex(#x, num, index); \
return *this; \
} \
TensorDesc get_dynamic_input_desc_##x(uint32_t index) const { return Operator::GetDynamicInputDesc(#x, index); } \
graphStatus update_dynamic_input_desc_##x(uint32_t index, const TensorDesc &tensorDesc) { \
return Operator::UpdateDynamicInputDesc(#x, index, tensorDesc); \
} \
_THIS_TYPE &set_dynamic_input_##x(uint32_t dstIndex, Operator &v) { \
Operator::SetInput(#x, dstIndex, v); \
return *this; \
} \
_THIS_TYPE &set_dynamic_input_##x(uint32_t dstIndex, Operator &v, const string &srcName) { \
Operator::SetInput(#x, dstIndex, v, srcName); \
return *this; \
} \
\
private: \
void __dy_input_##x() { \
(void)OpReg()
#define DYNAMIC_OUTPUT(x, t) \
N(); \
__dy_output_##x(); \
} \
\
public: \
_THIS_TYPE &create_dynamic_output_##x(unsigned int num, bool isPushBack = true) { \
Operator::DynamicOutputRegister(#x, num, isPushBack); \
return *this; \
} \
TensorDesc get_dynamic_output_desc_##x(unsigned int index) const { \
return Operator::GetDynamicOutputDesc(#x, index); \
} \
graphStatus update_dynamic_output_desc_##x(unsigned int index, const TensorDesc &tensorDesc) { \
return Operator::UpdateDynamicOutputDesc(#x, index, tensorDesc); \
} \
\
private: \
void __dy_output_##x() { \
#define DYNAMIC_OUTPUT(x, t) \
N(); \
__dy_output_##x(); \
} \
\
public: \
_THIS_TYPE &create_dynamic_output_##x(uint32_t num, bool isPushBack = true) { \
Operator::DynamicOutputRegister(#x, num, isPushBack); \
return *this; \
} \
TensorDesc get_dynamic_output_desc_##x(uint32_t index) const { return Operator::GetDynamicOutputDesc(#x, index); } \
graphStatus update_dynamic_output_desc_##x(uint32_t index, const TensorDesc &tensorDesc) { \
return Operator::UpdateDynamicOutputDesc(#x, index, tensorDesc); \
} \
\
private: \
void __dy_output_##x() { \
(void)OpReg()
#define GRAPH(x) \
@ -258,29 +256,29 @@ class OpReg {
Operator::SubgraphCountRegister(#x, 1); \
(void)OpReg()
#define DYNAMIC_GRAPH(x) \
N(); \
__graph_##x(); \
} \
\
public: \
static const string name_graph_##x() { return #x; } \
_THIS_TYPE &create_dynamic_subgraph_##x(unsigned int num) { \
Operator::SubgraphCountRegister(#x, num); \
return *this; \
} \
SubgraphBuilder get_dynamic_subgraph_builder_##x(unsigned int index) const { \
return Operator::GetDynamicSubgraphBuilder(#x, index); \
} \
Graph get_dynamic_subgraph_##x(unsigned int index) const { return Operator::GetDynamicSubgraph(#x, index); } \
_THIS_TYPE &set_dynamic_subgraph_builder_##x(unsigned int index, const SubgraphBuilder &v) { \
Operator::SetSubgraphBuilder(#x, index, v); \
return *this; \
} \
\
private: \
void __graph_##x() { \
Operator::SubgraphRegister(#x, true); \
#define DYNAMIC_GRAPH(x) \
N(); \
__graph_##x(); \
} \
\
public: \
static const string name_graph_##x() { return #x; } \
_THIS_TYPE &create_dynamic_subgraph_##x(uint32_t num) { \
Operator::SubgraphCountRegister(#x, num); \
return *this; \
} \
SubgraphBuilder get_dynamic_subgraph_builder_##x(uint32_t index) const { \
return Operator::GetDynamicSubgraphBuilder(#x, index); \
} \
Graph get_dynamic_subgraph_##x(uint32_t index) const { return Operator::GetDynamicSubgraph(#x, index); } \
_THIS_TYPE &set_dynamic_subgraph_builder_##x(uint32_t index, const SubgraphBuilder &v) { \
Operator::SetSubgraphBuilder(#x, index, v); \
return *this; \
} \
\
private: \
void __graph_##x() { \
Operator::SubgraphRegister(#x, true); \
(void)OpReg()
#define PASTE(g_register, y) g_register##y

@ -24,7 +24,7 @@
namespace ge {
static const int64_t UNKNOWN_DIM = -1;
static const int64_t UNKNOWN_DIM_NUM = -2;
static const std::vector<int64_t> UNKNOWN_SHAPE = {0};
static const std::vector<int64_t> UNKNOWN_SHAPE = {-1};
static const std::vector<int64_t> UNKNOWN_RANK = {-2};
#ifdef HOST_VISIBILITY

@ -40,6 +40,14 @@ enum FrameworkType {
FMK_TYPE_RESERVED,
};
enum OpEngineType {
ENGINE_SYS = 0, // default engine
ENGINE_AICORE = 1,
ENGINE_VECTOR = 2,
ENGINE_AICUBE = 3, // not support
ENGINE_AIVECTOR = 4 // not support
};
const char *const GE_ENGINE_ATTR_MEM_TYPE_HBM = "HBM";
// Data cache, including data address and length
@ -141,6 +149,7 @@ struct Options {
int32_t device_id;
std::string job_id;
bool isUseHcom;
bool isUseHvd;
bool deployMode;
bool isAICPUMode;
bool enable_atomic;

@ -442,6 +442,7 @@ REGISTER_OPTYPE_DECLARE(STREAMMERGE, "StreamMerge");
REGISTER_OPTYPE_DECLARE(ENDGRAPH, "EndGraph");
REGISTER_OPTYPE_DECLARE(SEND, "Send");
REGISTER_OPTYPE_DECLARE(RECV, "Recv");
REGISTER_OPTYPE_DECLARE(ENDOFSEQUENCE, "EndOfSequence");
REGISTER_OPTYPE_DECLARE(LABELSET, "LabelSet");
REGISTER_OPTYPE_DECLARE(LABELGOTO, "LabelGoto");
@ -508,6 +509,12 @@ REGISTER_OPTYPE_DECLARE(DEPTHWISEWEIGHT6D24D, "depthwise_weight_6d_2_4d");
REGISTER_OPTYPE_DECLARE(SQRTGRAD, "SqrtGrad");
REGISTER_OPTYPE_DECLARE(SIGMOIDGRAD, "SigmoidGrad");
// Horovod operator
REGISTER_OPTYPE_DECLARE(HVDCALLBACKALLREDUCE, "HorovodAllreduce");
REGISTER_OPTYPE_DECLARE(HVDCALLBACKALLGATHER, "HorovodAllgather");
REGISTER_OPTYPE_DECLARE(HVDCALLBACKBROADCAST, "HorovodBroadcast");
REGISTER_OPTYPE_DECLARE(HVDWAIT, "HorovodWait");
enum InputMode { INPUT = 0, CONST };
// Definition of the processing status enum of the process module

@ -1,61 +0,0 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef INC_FRAMEWORK_DLOG_LOG_H_
#define INC_FRAMEWORK_DLOG_LOG_H_
#include <string>
#if !defined(__ANDROID__) && !defined(ANDROID)
#include "toolchain/slog.h"
#else
#include <android/log.h>
#endif
#ifdef _MSC_VER
#define FUNC_NAME __FUNCTION__
#else
#define FUNC_NAME __PRETTY_FUNCTION__
#endif
#if !defined(__ANDROID__) && !defined(ANDROID)
#define DAV_LOGI(MOD_NAME, fmt, ...) dlog_info(static_cast<int>(GE), "%s:" #fmt, __FUNCTION__, ##__VA_ARGS__)
#define DAV_LOGW(MOD_NAME, fmt, ...) dlog_warn(static_cast<int>(GE), "%s:" #fmt, __FUNCTION__, ##__VA_ARGS__)
#define DAV_LOGE(MOD_NAME, fmt, ...) dlog_error(static_cast<int>(GE), "%s:" #fmt, __FUNCTION__, ##__VA_ARGS__)
#define DAV_LOGD(MOD_NAME, fmt, ...) dlog_debug(static_cast<int>(GE), "%s:" #fmt, __FUNCTION__, ##__VA_ARGS__)
#define DAV_EVENT(MOD_NAME, fmt, ...) dlog_event(static_cast<int>(GE), "%s:" #fmt, __FUNCTION__, ##__VA_ARGS__)
#else
#define DAV_LOGI(MOD_NAME, fmt, ...) \
__android_log_print(ANDROID_LOG_INFO, MOD_NAME, "%s %s(%d)::" #fmt, __FILE__, __FUNCTION__, __LINE__, ##__VA_ARGS__)
#define DAV_LOGW(MOD_NAME, fmt, ...) \
__android_log_print(ANDROID_LOG_WARN, MOD_NAME, "%s %s(%d)::" #fmt, __FILE__, __FUNCTION__, __LINE__, ##__VA_ARGS__)
#define DAV_LOGE(MOD_NAME, fmt, ...) \
__android_log_print(ANDROID_LOG_ERROR, MOD_NAME, "%s %s(%d)::" #fmt, __FILE__, __FUNCTION__, __LINE__, ##__VA_ARGS__)
#define DAV_LOGD(MOD_NAME, fmt, ...) \
__android_log_print(ANDROID_LOG_DEBUG, MOD_NAME, "%s %s(%d)::" #fmt, __FILE__, __FUNCTION__, __LINE__, ##__VA_ARGS__)
#define DAV_EVENT(MOD_NAME, fmt, ...) \
__android_log_print(ANDROID_LOG_DEBUG, MOD_NAME, "%s %s(%d)::" #fmt, __FILE__, __FUNCTION__, __LINE__, ##__VA_ARGS__)
#endif
#define DLOG_DECLARE(level) \
void Log_##level(const char *mod_name, const char *func, const char *file, int line, const char *format, ...)
namespace domi {
DLOG_DECLARE(INFO);
DLOG_DECLARE(WARNING);
DLOG_DECLARE(ERROR);
} // namespace domi
#endif // INC_FRAMEWORK_DLOG_LOG_H_

@ -0,0 +1,113 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef INC_FRAMEWORK_GE_RUNTIME_DAVINCI_MODEL_H_
#define INC_FRAMEWORK_GE_RUNTIME_DAVINCI_MODEL_H_
#include <memory>
#include <vector>
#include "ge_runtime/op_info.h"
#include "ge_runtime/task_info.h"
namespace ge {
namespace model_runner {
/// Immutable container describing a model to be loaded by the runtime:
/// task list, per-op metadata (data/output/constant/variable), stream
/// configuration and memory layout. All containers are copied in at
/// construction; the class exposes read-only accessors afterwards.
class DavinciModel {
 public:
  DavinciModel(const std::vector<std::shared_ptr<TaskInfo>> &task_info_list,
               const std::vector<std::shared_ptr<OpInfo>> &data_info_list,
               const std::vector<std::shared_ptr<OpInfo>> &output_info_list,
               const std::vector<std::shared_ptr<OpInfo>> &constant_info_list,
               const std::vector<model_runner::OpInfoPtr> &variable_info_list,
               const std::vector<uint32_t> &wait_active_stream_list,
               const std::vector<uint32_t> &force_copy_stream_list, uint64_t mem_size = 0, uint64_t weight_size = 0,
               uint64_t var_size = 0, uintptr_t logic_mem_base = 0, uintptr_t logic_weight_base = 0,
               uintptr_t logic_var_base = 0, uint32_t stream_num = 0, uint32_t batch_num = 0, uint32_t event_num = 0,
               int32_t priority = 0)
      : task_info_list_(task_info_list),
        data_info_list_(data_info_list),
        output_info_list_(output_info_list),
        constant_info_list_(constant_info_list),
        variable_info_list_(variable_info_list),
        wait_active_stream_list_(wait_active_stream_list),
        force_copy_stream_list_(force_copy_stream_list),
        mem_size_(mem_size),
        weight_size_(weight_size),
        var_size_(var_size),
        logic_mem_base_(logic_mem_base),
        logic_weight_base_(logic_weight_base),
        logic_var_base_(logic_var_base),
        stream_num_(stream_num),
        batch_num_(batch_num),
        event_num_(event_num),
        priority_(priority) {}

  ~DavinciModel() {}

  // --- Memory footprint / logical base addresses ---
  uint64_t GetMemSize() const { return mem_size_; }
  uint64_t GetWeightSize() const { return weight_size_; }
  uint64_t GetVarSize() const { return var_size_; }
  uintptr_t GetLogicMemBase() const { return logic_mem_base_; }
  uintptr_t GetLogicWeightBase() const { return logic_weight_base_; }
  uintptr_t GetLogicVarBase() const { return logic_var_base_; }

  // --- Execution resource counts ---
  uint32_t GetStreamNum() const { return stream_num_; }
  uint32_t GetBatchNum() const { return batch_num_; }
  uint32_t GetEventNum() const { return event_num_; }

  const std::vector<uint32_t> &GetWaitActiveStreams() const { return wait_active_stream_list_; }
  const std::vector<uint32_t> &GetForceCopyStreams() const { return force_copy_stream_list_; }

  int32_t GetPriority() const { return priority_; }

  const std::vector<std::shared_ptr<TaskInfo>> &GetTaskInfoList() const { return task_info_list_; }
  const std::vector<std::shared_ptr<OpInfo>> &GetDataInfoList() const { return data_info_list_; }
  const std::vector<std::shared_ptr<OpInfo>> &GetOutputInfoList() const { return output_info_list_; }
  // Bug fix: previously returned output_info_list_ (copy-paste error), which
  // made the constant-op info unreachable through this accessor.
  const std::vector<std::shared_ptr<OpInfo>> &GetConstantInfoList() const { return constant_info_list_; }
  const std::vector<model_runner::OpInfoPtr> &GetVariableInfoList() const { return variable_info_list_; }

 private:
  std::vector<std::shared_ptr<TaskInfo>> task_info_list_;
  std::vector<std::shared_ptr<OpInfo>> data_info_list_;
  std::vector<std::shared_ptr<OpInfo>> output_info_list_;
  std::vector<std::shared_ptr<OpInfo>> constant_info_list_;
  std::vector<model_runner::OpInfoPtr> variable_info_list_;
  std::vector<uint32_t> wait_active_stream_list_;
  std::vector<uint32_t> force_copy_stream_list_;
  uint64_t mem_size_;
  uint64_t weight_size_;
  uint64_t var_size_;
  uintptr_t logic_mem_base_;
  uintptr_t logic_weight_base_;
  uintptr_t logic_var_base_;
  uint32_t stream_num_;
  uint32_t batch_num_;
  uint32_t event_num_;
  int32_t priority_;

  // Disable copy constructor and assignment operator.
  DavinciModel &operator=(const DavinciModel &) = delete;
  DavinciModel(const DavinciModel &) = delete;
};
} // namespace model_runner
} // namespace ge
#endif // INC_FRAMEWORK_GE_RUNTIME_DAVINCI_MODEL_H_

@ -0,0 +1,58 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef INC_FRAMEWORK_GE_RUNTIME_MODEL_RUNNER_H_
#define INC_FRAMEWORK_GE_RUNTIME_MODEL_RUNNER_H_
#include <memory>
#include <unordered_map>
#include <vector>
#include "common/ge_inner_error_codes.h"
#include "common/ge_types.h"
#include "ge_runtime/davinci_model.h"
namespace ge {
namespace model_runner {
class RuntimeModel;
/// Singleton registry that loads, runs and unloads DavinciModel instances,
/// keyed by model_id. Declaration only; bodies live in the implementation.
class ModelRunner {
public:
/// Returns the process-wide singleton instance.
static ModelRunner &Instance();
/// Loads `davinci_model` onto the given device/session under `model_id`.
/// Returns true on success — presumably; confirm against the implementation.
bool LoadDavinciModel(uint32_t device_id, uint64_t session_id, uint32_t model_id,
std::shared_ptr<DavinciModel> davinci_model, std::shared_ptr<ModelListener> listener);
/// Task ids generated when the model identified by `model_id` was loaded.
const std::vector<uint32_t> &GetTaskIdList(uint32_t model_id) const;
/// Removes the model identified by `model_id`; returns false if absent or on failure.
bool UnloadModel(uint32_t model_id);
/// Executes the loaded model on `input_data`, filling `output_data`.
bool RunModel(uint32_t model_id, const InputData &input_data, OutputData *output_data);
/// Queries input/output tensor descriptions and formats for a loaded model.
/// All out-parameters are pointers filled by the callee.
bool GetInputOutputDescInfo(uint32_t model_id, bool zero_copy, std::vector<InputOutputDescInfo> *input_desc,
std::vector<InputOutputDescInfo> *output_desc, std::vector<uint32_t> *input_format,
std::vector<uint32_t> *output_format);
private:
// Singleton: construction/destruction restricted to Instance().
ModelRunner() = default;
~ModelRunner() = default;
// Loaded models, keyed by model_id.
std::unordered_map<uint32_t, std::shared_ptr<RuntimeModel>> runtime_models_;
};
} // namespace model_runner
} // namespace ge
#endif // INC_FRAMEWORK_GE_RUNTIME_MODEL_RUNNER_H_

@ -0,0 +1,72 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef INC_FRAMEWORK_GE_RUNTIME_OP_INFO_H_
#define INC_FRAMEWORK_GE_RUNTIME_OP_INFO_H_
#include <memory>
#include <string>
#include <vector>
namespace ge {
namespace model_runner {
/// Shape/type/layout description of a single tensor.
struct TensorInfo {
  /// Total element count: the product of all entries in `dims`.
  /// Returns 0 for an empty shape (no dimensions).
  int64_t GetShapeSize() const {
    if (dims.empty()) {
      return 0;
    }
    int64_t res = 1;
    for (auto dim : dims) {
      res *= dim;
    }
    return res;
  }

  /// Dimension at `index`, or 0 when `index` is out of range.
  /// Fix: now const-qualified — the method only reads `dims`, and the
  /// original non-const form could not be called through const references.
  int64_t GetDim(uint32_t index) const {
    if (index >= dims.size()) {
      return 0;
    }
    return dims[index];
  }

  std::vector<int64_t> dims;  // tensor shape
  uint32_t datatype;          // data-type enum value (framework-defined)
  uint32_t format;            // layout/format enum value (framework-defined)
  uint32_t real_dim_cnt;
  uint32_t size;              // presumably the buffer size in bytes — confirm with producers
  bool is_output;
};
/// Metadata describing one operator instance inside a model.
struct OpInfo {
// Position of the op in the model's op sequence — presumably; confirm with producers.
uint32_t index;
// Op instance name.
std::string name;
// Op type string.
std::string type;
// Whether the variable is broadcast — only meaningful for variable ops; TODO confirm.
bool var_is_broadcast;
// Raw addresses of input/output buffers.
std::vector<uintptr_t> input_addrs;
std::vector<uintptr_t> output_addrs;
// Tensor descriptions for inputs, outputs and weights.
std::vector<TensorInfo> input_tensors;
std::vector<TensorInfo> output_tensors;
std::vector<TensorInfo> weight_tensors;
// Names and indices of source ops feeding this op.
std::vector<std::string> src_name;
std::vector<int64_t> src_index;
// Serialized weight payload — presumably raw bytes; confirm against usage.
std::string weight_data;
};
using TensorInfoPtr = std::shared_ptr<TensorInfo>;
using OpInfoPtr = std::shared_ptr<OpInfo>;
} // namespace model_runner
} // namespace ge
#endif // INC_FRAMEWORK_GE_RUNTIME_OP_INFO_H_

File diff suppressed because it is too large Load Diff

@ -23,6 +23,7 @@
#include <vector>
#include "ge/ge_ir_build.h"
#include "common/ge_inner_error_codes.h"
#include "common/ge_types.h"
#include "graph/ge_tensor.h"
#include "graph/graph.h"
#include "graph/op_desc.h"
@ -30,9 +31,13 @@
namespace ge {
class GeGenerator {
public:
static GeGenerator &GetInstance() {
static GeGenerator Instance;
return Instance;
}
GeGenerator() = default;
~GeGenerator() = default;
~GeGenerator() { (void)Finalize(); }
GeGenerator(const GeGenerator &) = delete;
@ -60,10 +65,25 @@ class GeGenerator {
///
Status BuildSingleOpModel(OpDescPtr &op_desc, const std::vector<GeTensor> &inputs,
const std::vector<GeTensor> &outputs, const std::string &model_file_name);
///
/// @ingroup ge
/// @brief: Build single Op into model buff.
/// @param [in] op_desc: the OP description.
/// @param [in] inputs: input tensors.
/// @param [in] outputs: output tensors.
/// @param [in] engine_type: specific engine.
/// @param [out] model_buff: model buff of single op.
/// @return SUCCESS or FAILED
Status BuildSingleOpModel(OpDescPtr &op_desc, const vector<GeTensor> &inputs, const vector<GeTensor> &outputs,
OpEngineType engine_type, ModelBufferData &model_buff);
private:
Status GenerateModel(const Graph &graph, const string &file_name_prefix, const vector<GeTensor> &inputs,
ge::ModelBufferData &model, bool is_offline = true);
Status BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &inputs, const vector<GeTensor> &outputs,
const string &model_file_name, OpEngineType engine_type, ModelBufferData &model_buff,
bool is_offline = true);
class Impl;
std::shared_ptr<Impl> impl_;

@ -0,0 +1,113 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef INC_FRAMEWORK_OMG_OMG_H_
#define INC_FRAMEWORK_OMG_OMG_H_
#include <google/protobuf/message.h>
#include <string>
#include <unordered_map>
#include <vector>
#include "framework/common/types.h"
#include "framework/omg/omg_inner_types.h"
#include "proto/ge_ir.pb.h"
#include "proto/om.pb.h"
#include "graph/compute_graph.h"
#include "graph/graph.h"
#include "graph/model.h"
#include "runtime/kernel.h"
using domi::Status;
using std::pair;
using std::string;
using std::unordered_map;
using std::vector;
namespace ge {
/**
* @ingroup domi_omg
* @brief init omg context
* @return void
*/
Status InitDomiOmgContext(const string &input_shape, const string &input_format, const string &net_format,
bool is_dynamic_input);
/**
* @ingroup domi_omg
* @brief generate graph based on the input model file and weight file
* @param [out] graph graph
* @param [in] model_file path of model file
* @param [in] weights_file path of weight file
* @param [in] type type of the input model
* @param [in] op_conf op mapping configuration
* @param [in] target type of platform. If a tiny model is generated, set target to tiny
* @param [in] run_mode run mode
* @param [in] enable_l2dynamic enable l2dynamic
* @param [in] is_dynamic_input dynamic input, true of false
* @param [in] atc_params multiply atc params
* @return Status result code
*/
Status ParseGraph(ge::Graph &graph, const std::map<string, string> &atc_params, const char *model_file,
const char *weights_file, domi::FrameworkType type, const char *op_conf = nullptr,
const char *target = nullptr, RunMode run_mode = GEN_OM_MODEL, bool is_dynamic_input = false);
/**
* @ingroup domi_omg
* @brief generates a simplified JSON file based on the key value of the offline model file in protobuf format
* @param [in] model_file path of offline model file
* @param [out] json_file path of json file
* @param [in] key encryption key
* @return Status result code
*/
Status ConvertOmModelToJson(const char *model_file, const char *json_file);
Status ConvertPbtxtToJson(const char *model_file, const char *json_file);
/**
* @ingroup domi_omg
* @brief convert the model file in protobuf format into a JSON file.
* @param [in] framework type of model
* @param [in] om model_file path of offline model file
* @param [out] json_file path of json file
* @param [in] key encryption key
* @return Status result code
*/
Status ConvertFwkModelToJson(domi::FrameworkType framework, const char *model_file, const char *json_file);
void GetGroupName(ge::proto::ModelDef &model);
void FindParserSo(const string &path, vector<string> &fileList, string &caffe_parser_path);
Status CheckCustomAiCpuOpLib();
Status DumpInfershapeJson(const ge::Graph &graph, const char *json_file);
Status SetOutputNodeInfo(ge::Graph &graph, const std::string &output_type, const std::string &output_format);
Status GetOutputLeaf(ge::NodePtr node, std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info,
std::vector<std::string> &output_nodes_name);
} // namespace ge
namespace domi {
/**
* @ingroup domi_omg
* @brief get omg context
* @return reference of OmgContext
*/
ge::OmgContext &GetContext();
} // namespace domi
#endif // INC_FRAMEWORK_OMG_OMG_H_

@ -44,11 +44,10 @@ namespace ge {
* @brief run model
*/
enum RunMode {
GEN_OM_MODEL = 0, // generate offline model file
MODEL_TO_JSON = 1, // convert to JSON file
MODEL_TO_JSON_WITH_SHAPE = 2, // convert to json file with shape
ONLY_PRE_CHECK = 3, // only for pre-check
PBTXT_TO_JSON = 5 // pbtxt to json
GEN_OM_MODEL = 0, // generate offline model file
MODEL_TO_JSON = 1, // convert to JSON file
ONLY_PRE_CHECK = 3, // only for pre-check
PBTXT_TO_JSON = 5 // pbtxt to json
};
///
@ -93,6 +92,8 @@ struct OmgContext {
std::map<std::string, std::vector<int32_t>> out_nodes_map;
// user-designated output nodes (used for determining the order)
std::vector<std::pair<std::string, int32_t>> user_out_nodes;
// net output nodes (either user_out_nodes or leaf nodes)
std::vector<std::string> net_out_nodes;
// path for the aicpu custom operator so_file
std::vector<std::string> aicpu_op_run_paths;
// ddk version

@ -235,6 +235,8 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A
std::vector<NodePtr> &stack);
graphStatus BFSTopologicalSorting(std::vector<NodePtr> &node_vec, std::map<NodePtr, uint32_t> &map_in_edge_num,
std::deque<NodePtr> &stack);
graphStatus BFSTopologicalSortingWithGroup(std::vector<NodePtr> &node_vec,
std::map<NodePtr, uint32_t> &map_in_edge_num, std::deque<NodePtr> &stack);
graphStatus CollectBreadthOutNode(const NodePtr &node, std::map<NodePtr, uint32_t> &map_in_edge_num,
std::map<string, NodePtr> &breadth_node_map);
graphStatus TopologicalSortingGraph();

@ -94,6 +94,10 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FORMAT;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STORAGE_FORMAT;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STORAGE_SHAPE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FILTER_FORMAT;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LRN_K;
@ -133,6 +137,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string NEW_AIPP_CONV_OP;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SESSION_GRAPH_ID;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PARENT_GRAPH_NAME;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MULTISHAPE_BATCHLIST;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MULTISHAPE_BATCHLIST_SIZE;
@ -692,6 +697,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MOD
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_ZERO_COPY_MEMORY_SIZE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_OUT_NODES_NAME;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_WEIGHT_SIZE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_BASE_ADDR;
@ -920,6 +927,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_SWITCH_COND;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ACTIVE_STREAM_LIST;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCHN_PRED_VALUE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SUBGRAPH_FIRST_ACTIVE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_BRANCH_NODE_LABEL;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG;
@ -999,14 +1007,17 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_FOR_L2_FUSION;
// functional ops attr
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IF_THEN_BRANCH;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IF_ELSE_BRANCH;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WHILE_COND;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WHILE_BODY;
// used for label switch
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LABEL_SWITCH_INDEX;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LABEL_SWITCH_LIST;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SUBGRAPH_END_NODE;
// Varible
// Variable
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_SRC_VAR_NAME;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SRC_VAR_NAME;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_PRE_PEER_OUT_INDEX;
@ -1032,6 +1043,13 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM
// Dynamic stitch
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DYNAMIC_STITCH_ATTR_NAME_NUM;
// Used for support Horovod
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_INTER_EVENT_IDENTIFY;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_HOROVOD_ATTR_REDUCE_TYPE;
// for gradient group
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_HCCL_FUSED_GROUP;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_HCCL_FUSED_FLAG;
} // namespace ge
#endif // INC_GRAPH_DEBUG_GE_ATTR_DEFINE_H_

@ -264,6 +264,8 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder {
graphStatus SetSubgraphInstanceName(uint32_t index, const std::string &name);
void RemoveSubgraphInstanceName(const std::string &name);
graphStatus GetSubgraphNameByInstanceName(const std::string &instance_name, std::string &subgraph_name) const;
protected:
ProtoAttrMapHelper MutableAttrMap() override;
ConstProtoAttrMapHelper GetAttrMap() const override;
@ -288,7 +290,7 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder {
// subgraph ir names to type, for a `if` operator:
// then_branch: static
// else_branch: dynamic
// else_branch: static
// or for a `case` op:
// branches: dynamic
std::map<std::string, SubgraphType> subgraph_ir_names_to_type_;

@ -0,0 +1,46 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef INC_GRAPH_RUNTIME_INFERENCE_CONTEXT_H_
#define INC_GRAPH_RUNTIME_INFERENCE_CONTEXT_H_
#include <map>
#include <memory>
#include <mutex>
#include <vector>
#include "external/graph/ge_error_codes.h"
#include "external/graph/tensor.h"
namespace ge {
// Holds tensors produced while a graph executes so that later value-dependent
// inference can read them back. Contexts live in a process-wide registry
// (contexts_) keyed by a string context id and owned via unique_ptr.
class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY RuntimeInferenceContext {
public:
// Looks up the context registered under context_id and stores it in *ctx.
// The registry retains ownership of the returned pointer.
static graphStatus GetContext(const std::string &context_id, RuntimeInferenceContext **ctx);
// Creates a new context and registers it under context_id.
static graphStatus CreateContext(const std::string &context_id);
// Removes (and thereby destroys) the context registered under context_id.
static void DestroyContext(const std::string &context_id);
// Records the tensor produced by output `output_id` of node `node_id`;
// takes the tensor by rvalue reference, i.e. the context keeps it by move.
graphStatus SetTensor(int64_t node_id, int output_id, Tensor &&tensor);
// Reads back the tensor previously stored for (node_id, output_id).
graphStatus GetTensor(int64_t node_id, int output_id, Tensor &tensor);
private:
// node_id -> tensors indexed by output id. mu_ presumably guards this map
// against concurrent Set/Get calls — TODO confirm in the .cc implementation.
std::map<int64_t, std::vector<Tensor>> tensors_;
std::mutex mu_;
// Global context-id -> context registry; ctx_mu_ presumably guards it — confirm in .cc.
static std::map<std::string, std::unique_ptr<RuntimeInferenceContext>> contexts_;
static std::mutex ctx_mu_;
};
} // namespace ge
#endif // INC_GRAPH_RUNTIME_INFERENCE_CONTEXT_H_

@ -29,6 +29,18 @@
#include "graph/graph.h"
#include "graph/model.h"
// Dump `compute_graph` (and every one of its subgraphs) under the tag `name`,
// in both the internal GE format and ONNX.
// Each expansion site keeps its own static counter so repeated dumps from the
// same call site get unique, monotonically increasing subgraph suffixes.
// NOTE: counter widened from int8_t to int64_t — int8_t wrapped after 127
// subgraph dumps, producing negative/duplicate suffixes in the dump names.
#define GE_DUMP(compute_graph, name) \
do { \
GraphUtils::DumpGEGraph(compute_graph, name); \
GraphUtils::DumpGEGraphToOnnx(*compute_graph, name); \
for (const auto &sub_graph_func : compute_graph->GetAllSubgraphs()) { \
static int64_t i = 0; \
auto sub_graph_func_name = std::string(name) + std::string("_sub_graph_") + std::to_string(i++); \
GraphUtils::DumpGEGraph(sub_graph_func, sub_graph_func_name); \
GraphUtils::DumpGEGraphToOnnx(*sub_graph_func, sub_graph_func_name); \
} \
} while (0)
#define REFER_ATTR_VALUE(VT_ENUM, DataType, attr, ret) \
do { \
DataType ret; \
@ -155,6 +167,8 @@ class GraphUtils {
static graphStatus InsertNodeBetweenDataAnchors(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst,
const NodePtr &new_node);
static graphStatus RemoveSubgraphRecursively(const ComputeGraphPtr &compute_graph, const NodePtr &remove_node);
static graphStatus RemoveNodeWithoutRelink(const ComputeGraphPtr &compute_graph, const NodePtr &node);
static graphStatus InsertTransNode(ComputeGraphPtr compute_graph, const InDataAnchorPtr &in_data_anchor,
@ -299,6 +313,24 @@ class GraphUtils {
std::map<std::string, std::vector<NodeIndexIO>> &symbol_to_anchors,
std::map<std::string, std::string> &anchor_to_symbol);
///
/// Determine whether the graph is an UNKNOWN_SHAPE graph, based on whether the graph
/// or any of its subgraphs contains UNKNOWN_SHAPE operators.
/// Note: This function will only look 'down' from the graph, not 'up'. For example, the following
/// scenario (K for known shape, U for unknown shape), ROOT graph is UNKNOWN_SHAPE while SUB graph is KNOWN_SHAPE
/// ROOT graph: A -----> B -----> C
/// K subgraph U
/// |
/// V
/// SUB graph: D --> E --> F
/// K K K
/// @param [in] graph
/// @return bool
///
static bool IsUnknownShapeGraph(const ComputeGraphPtr &graph);
static NodePtr FindNodeFromAllNodes(ComputeGraphPtr &graph, const std::string &name);
private:
///
/// Get reference-mapping for in_data_anchors of node
@ -438,6 +470,11 @@ class ComputeGraphBuilder {
///
NodePtr GetNode(const std::string &name);
/// @brief Get all nodes
/// @return std::vector<NodePtr>
///
std::vector<NodePtr> GetAllNodes();
protected:
///
/// @brief Build nodes
@ -535,6 +572,13 @@ class CompleteGraphBuilder : public ComputeGraphBuilder {
///
CompleteGraphBuilder &AddOutput(const std::string &owner_node_name, uint32_t anchor_ind);
///
/// @brief Add target for graph
/// @param [in] target_name
/// @return CompleteGraphBuilder
///
CompleteGraphBuilder &AddTarget(const std::string &target_name);
///
/// @brief Set parent-node of graph
/// @param [in] parent_node
@ -590,10 +634,19 @@ class CompleteGraphBuilder : public ComputeGraphBuilder {
///
void AddRetValNodes(graphStatus &error_code, std::string &error_msg);
///
/// @brief Build target-nodes for graph
/// @param [out] error_code
/// @param [out] error_msg
/// @return void
///
void BuildGraphTargets(graphStatus &error_code, std::string &error_msg);
std::string name_;
NodePtr parent_node_;
std::map<uint32_t, std::pair<std::vector<std::string>, std::vector<uint32_t>>> graph_inputs_;
std::vector<std::pair<std::string, uint32_t>> graph_outputs_;
std::vector<std::string> graph_targets_;
// index_of_graph_input -> in_anchor_index_of_parent_node
std::map<uint32_t, uint32_t> input_mapping_;

@ -17,10 +17,23 @@
#ifndef INC_GRAPH_UTILS_NODE_UTILS_H_
#define INC_GRAPH_UTILS_NODE_UTILS_H_
#include <set>
#include <map>
#include <vector>
#include "graph/node.h"
namespace ge {
// Op types of Const-like Ops.
extern const std::set<std::string> kConstOpTypes;
// Op types of If-like Ops.
extern const std::set<std::string> kIfOpTypes;
// Op types of While-like Ops.
extern const std::set<std::string> kWhileOpTypes;
// Op types of Case-like Ops.
extern const std::set<std::string> kCaseOpTypes;
// Op types of For-like Ops.
extern const std::set<std::string> kForOpTypes;
class NodeUtils {
public:
static graphStatus AddSendEventId(const NodePtr &node, const uint32_t &event_id);
@ -94,6 +107,13 @@ class NodeUtils {
///
static bool GetConstOpType(const NodePtr &in_node, std::string &op_type);
///
/// @brief Remove node-related subgraphs, including subgraphs of nodes in the subgraph.
/// @param [in] node
/// @return return GRAPH_SUCCESS if remove successfully, other for failed.
///
static graphStatus RemoveSubgraphsOnNode(const NodePtr &node);
private:
static std::map<NodePtr, std::vector<uint32_t>> map_send_info_;
static std::map<NodePtr, std::vector<uint32_t>> map_recv_info_;

@ -24,6 +24,7 @@
#include "graph/ge_error_codes.h"
#include "graph/types.h"
#include "graph/usr_types.h"
#include "register/register_types.h"
namespace ge {
class TypeUtils {
@ -37,6 +38,7 @@ class TypeUtils {
static std::string FormatToSerialString(Format format);
static Format SerialStringToFormat(const std::string &str);
static Format DataFormatToFormat(const std::string &str);
static Format DomiFormatToFormat(domi::domiTensorFormat_t domi_format);
static graphStatus Usr2DefQuantizeFactorParams(const UsrQuantizeFactorParams &usr, QuantizeFactorParams &def);
static graphStatus Def2UsrQuantizeFactorParams(const QuantizeFactorParams &def, UsrQuantizeFactorParams &usr);

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save