!335 Synchronize latest Ascend software suite 19 Nov 2020

From: @nicholas_yhr
Reviewed-by: @youui,@liujunzhu
Signed-off-by: @liujunzhu
pull/335/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 9153665631

@ -59,6 +59,25 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Session {
/// ///
Status AddGraph(uint32_t graphId, const Graph &graph, const std::map<std::string, std::string> &options); Status AddGraph(uint32_t graphId, const Graph &graph, const std::map<std::string, std::string> &options);
///
/// @ingroup client
/// @brief add a copy graph with a specific graphId
/// @param [in] graphId graph id
/// @param [in] graph the graph
/// @return Status result of function
///
Status AddGraphWithCopy(uint32_t graph_id, const Graph &graph);
///
/// @ingroup client
/// @brief add a copy graph with a specific graphId and graphOptions
/// @param [in] graphId graph id
/// @param [in] graph the graph
/// @param [in] options graph options
/// @return Status result of function
///
Status AddGraphWithCopy(uint32_t graph_id, const Graph &graph, const std::map<AscendString, AscendString> &options);
/// ///
/// @ingroup ge_graph /// @ingroup ge_graph
/// @brief remove a graph of the session with specific session id /// @brief remove a graph of the session with specific session id

@ -245,6 +245,12 @@ const std::string INPUT_FP16_NODES = "ge.INPUT_NODES_SET_FP16";
// 0: close debug; 1: open TBE compiler; 2: open ccec compiler // 0: close debug; 1: open TBE compiler; 2: open ccec compiler
const std::string OP_DEBUG_LEVEL = "ge.opDebugLevel"; const std::string OP_DEBUG_LEVEL = "ge.opDebugLevel";
// Configure model bank path
const std::string MDL_BANK_PATH_FLAG = "ge.mdl_bank_path";
// Configure op bank path
const std::string OP_BANK_PATH_FLAG = "ge.op_bank_path";
// Graph run mode // Graph run mode
enum GraphRunMode { PREDICTION = 0, TRAIN }; enum GraphRunMode { PREDICTION = 0, TRAIN };
@ -315,13 +321,28 @@ static const char *const OPTYPELIST_FOR_IMPLMODE = ge::OPTYPELIST_FOR_IMPLMODE.c
static const char *const DEBUG_DIR = ge::DEBUG_DIR; static const char *const DEBUG_DIR = ge::DEBUG_DIR;
static const char *const OP_COMPILER_CACHE_DIR = ge::OP_COMPILER_CACHE_DIR; static const char *const OP_COMPILER_CACHE_DIR = ge::OP_COMPILER_CACHE_DIR;
static const char *const OP_COMPILER_CACHE_MODE = ge::OP_COMPILER_CACHE_MODE; static const char *const OP_COMPILER_CACHE_MODE = ge::OP_COMPILER_CACHE_MODE;
static const char *const MDL_BANK_PATH_FLAG = ge::MDL_BANK_PATH_FLAG.c_str();
static const char *const OP_BANK_PATH_FLAG = ge::OP_BANK_PATH_FLAG.c_str();
// for interface: aclgrphBuildModel // for interface: aclgrphBuildModel
const std::set<std::string> ir_builder_suppported_options = { const std::set<std::string> ir_builder_suppported_options = {INPUT_FORMAT,
INPUT_FORMAT, INPUT_SHAPE, OP_NAME_MAP, INPUT_SHAPE,
DYNAMIC_BATCH_SIZE, DYNAMIC_IMAGE_SIZE, DYNAMIC_DIMS, OP_NAME_MAP,
INSERT_OP_FILE, PRECISION_MODE, EXEC_DISABLE_REUSED_MEMORY, DYNAMIC_BATCH_SIZE,
AUTO_TUNE_MODE, OUTPUT_TYPE, OUT_NODES, DYNAMIC_IMAGE_SIZE,
INPUT_FP16_NODES, LOG_LEVEL}; DYNAMIC_DIMS,
INSERT_OP_FILE,
PRECISION_MODE,
EXEC_DISABLE_REUSED_MEMORY,
AUTO_TUNE_MODE,
OUTPUT_TYPE,
OUT_NODES,
INPUT_FP16_NODES,
LOG_LEVEL,
OP_DEBUG_LEVEL,
DEBUG_DIR,
OP_COMPILER_CACHE_DIR,
OP_COMPILER_CACHE_MODE};
// for interface: aclgrphParse // for interface: aclgrphParse
const std::set<std::string> ir_parser_suppported_options = {INPUT_FORMAT, const std::set<std::string> ir_parser_suppported_options = {INPUT_FORMAT,
@ -336,7 +357,9 @@ const std::set<std::string> ir_parser_suppported_options = {INPUT_FORMAT,
OUT_NODES, OUT_NODES,
COMPRESS_WEIGHT_CONF, COMPRESS_WEIGHT_CONF,
ENABLE_SCOPE_FUSION_PASSES, ENABLE_SCOPE_FUSION_PASSES,
LOG_LEVEL}; LOG_LEVEL,
MDL_BANK_PATH_FLAG,
OP_BANK_PATH_FLAG};
// for interface: aclgrphBuildInitialize // for interface: aclgrphBuildInitialize
const std::set<std::string> global_options = {CORE_TYPE, const std::set<std::string> global_options = {CORE_TYPE,

@ -31,6 +31,18 @@ class AscendString {
const char* GetString() const; const char* GetString() const;
bool operator<(const AscendString& d) const;
bool operator>(const AscendString& d) const;
bool operator<=(const AscendString& d) const;
bool operator>=(const AscendString& d) const;
bool operator==(const AscendString& d) const;
bool operator!=(const AscendString& d) const;
private: private:
std::shared_ptr<std::string> name_; std::shared_ptr<std::string> name_;
}; };

@ -94,6 +94,7 @@ using FusionParseParamFunc =
std::function<domi::Status(const std::vector<const google::protobuf::Message *>, ge::Operator &)>; std::function<domi::Status(const std::vector<const google::protobuf::Message *>, ge::Operator &)>;
using FusionParseParamByOpFunc = std::function<domi::Status(const std::vector<ge::Operator> &, ge::Operator &)>; using FusionParseParamByOpFunc = std::function<domi::Status(const std::vector<ge::Operator> &, ge::Operator &)>;
using ParseSubgraphFunc = std::function<Status(const std::string &subgraph_name, const ge::Graph &graph)>; using ParseSubgraphFunc = std::function<Status(const std::string &subgraph_name, const ge::Graph &graph)>;
using ParseOpToGraphFunc = std::function<Status(const ge::Operator &, ge::Graph &)>;
class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistrationData { class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistrationData {
public: public:
@ -125,6 +126,8 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistrationData {
OpRegistrationData &InputReorderVector(const vector<int> &input_order); OpRegistrationData &InputReorderVector(const vector<int> &input_order);
OpRegistrationData &ParseOpToGraphFn(const ParseOpToGraphFunc &parse_op_to_graph_fn);
domi::ImplyType GetImplyType() const; domi::ImplyType GetImplyType() const;
std::string GetOmOptype() const; std::string GetOmOptype() const;
std::set<std::string> GetOriginOpTypeSet() const; std::set<std::string> GetOriginOpTypeSet() const;
@ -134,6 +137,7 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistrationData {
FusionParseParamFunc GetFusionParseParamFn() const; FusionParseParamFunc GetFusionParseParamFn() const;
FusionParseParamByOpFunc GetFusionParseParamByOpFn() const; FusionParseParamByOpFunc GetFusionParseParamByOpFn() const;
ParseSubgraphFunc GetParseSubgraphPostFn() const; ParseSubgraphFunc GetParseSubgraphPostFn() const;
ParseOpToGraphFunc GetParseOpToGraphFn() const;
private: private:
std::shared_ptr<OpRegistrationDataImpl> impl_; std::shared_ptr<OpRegistrationDataImpl> impl_;

@ -18,10 +18,12 @@
#define INC_FRAMEWORK_COMMON_DEBUG_LOG_H_ #define INC_FRAMEWORK_COMMON_DEBUG_LOG_H_
#include <string> #include <string>
#include <sstream>
#include "runtime/rt.h" #include "runtime/rt.h"
#include "common/string_util.h" #include "common/string_util.h"
#include "common/util.h" #include "common/util.h"
#include "common/util/error_manager/error_manager.h"
#include "framework/common/debug/ge_log.h" #include "framework/common/debug/ge_log.h"
#include "ge/ge_api_error_codes.h" #include "ge/ge_api_error_codes.h"
@ -253,4 +255,29 @@
exec_expr1; \ exec_expr1; \
} }
#define GE_ERRORLOG_AND_ERRORMSG(_status, errormsg) \
{ \
GELOGE(_status, "%s", errormsg); \
ErrorManager::GetInstance().ATCReportErrMessage("E10043", {"reason"}, {errormsg}); \
}
#define GE_CHK_LOG_AND_ERRORMSG(expr, _status, errormsg) \
do { \
bool b = (expr); \
if (!b) { \
GELOGE(_status, "%s", errormsg); \
ErrorManager::GetInstance().ATCReportErrMessage("E10043", {"reason"}, {errormsg}); \
return _status; \
} \
} while (0)
template <typename T>
std::string FmtToStr(const T &t) {
std::string fmt;
std::stringstream st;
st << "[" << t << "]";
fmt = st.str();
return fmt;
}
#endif // INC_FRAMEWORK_COMMON_DEBUG_LOG_H_ #endif // INC_FRAMEWORK_COMMON_DEBUG_LOG_H_

@ -70,6 +70,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROFIL
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROFILE_STOP_VALUE; FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROFILE_STOP_VALUE;
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::map<std::string, std::string> PROFILE_COMPONENT_MAP; FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::map<std::string, std::string> PROFILE_COMPONENT_MAP;
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROFILE_CONFIG; FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROFILE_CONFIG;
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROFILE_MODEL_ID;
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string MODEL_ATTR_TASKS; FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string MODEL_ATTR_TASKS;
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string MODEL_ATTR_TASK_GEN_BASE_ADDR; FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string MODEL_ATTR_TASK_GEN_BASE_ADDR;

@ -270,6 +270,8 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeExecutor {
static ge::Status ReleaseSingleOpResource(void *stream); static ge::Status ReleaseSingleOpResource(void *stream);
static ge::Status GetDeviceIdByModelId(uint32_t model_id, uint32_t &device_id);
ge::Status GetBatchInfoSize(uint32_t model_id, size_t &shape_count); ge::Status GetBatchInfoSize(uint32_t model_id, size_t &shape_count);
ge::Status GetOrigInputInfo(uint32_t model_id, uint32_t index, OriginInputInfo &orig_input_info); ge::Status GetOrigInputInfo(uint32_t model_id, uint32_t index, OriginInputInfo &orig_input_info);
ge::Status GetAllAippInputOutputDims(uint32_t model_id, uint32_t index, std::vector<InputOutputDims> &input_dims, ge::Status GetAllAippInputOutputDims(uint32_t model_id, uint32_t index, std::vector<InputOutputDims> &input_dims,

@ -1115,6 +1115,9 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_DYN
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_ATC_USER_DEFINE_DATATYPE; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_ATC_USER_DEFINE_DATATYPE;
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_ATC_USER_DEFINE_FORMAT; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_ATC_USER_DEFINE_FORMAT;
// atc user def dtype&format
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_ATC_USER_DEFINE_OUTPUT_NODES;
// for fusion op plugin // for fusion op plugin
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSIONOP_ORIGINAL_TYPE; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSIONOP_ORIGINAL_TYPE;

@ -42,6 +42,7 @@ file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR}
"detail/*.cc" "detail/*.cc"
"debug/*.cc" "debug/*.cc"
"option/*.cc" "option/*.cc"
"transformer/src/*cc"
) )
# include directories # include directories

@ -30,4 +30,66 @@ const char* AscendString::GetString() const {
return (*name_).c_str(); return (*name_).c_str();
} }
bool AscendString::operator<(const AscendString& d) const {
if (name_ == nullptr && d.name_ == nullptr) {
return false;
} else if (name_ == nullptr) {
return true;
} else if (d.name_ == nullptr) {
return false;
}
return (*name_ < *(d.name_));
}
bool AscendString::operator>(const AscendString& d) const {
if (name_ == nullptr && d.name_ == nullptr) {
return false;
} else if (name_ == nullptr) {
return false;
} else if (d.name_ == nullptr) {
return true;
}
return (*name_ > *(d.name_));
}
bool AscendString::operator==(const AscendString& d) const {
if (name_ == nullptr && d.name_ == nullptr) {
return true;
} else if (name_ == nullptr) {
return false;
} else if (d.name_ == nullptr) {
return false;
}
return (*name_ == *(d.name_));
}
bool AscendString::operator<=(const AscendString& d) const {
if (name_ == nullptr) {
return true;
} else if (d.name_ == nullptr) {
return false;
}
return (*name_ <= *(d.name_));
}
bool AscendString::operator>=(const AscendString& d) const {
if (d.name_ == nullptr) {
return true;
} else if (name_ == nullptr) {
return false;
}
return (*name_ >= *(d.name_));
}
bool AscendString::operator!=(const AscendString& d) const {
if (name_ == nullptr && d.name_ == nullptr) {
return false;
} else if (name_ == nullptr) {
return true;
} else if (d.name_ == nullptr) {
return true;
}
return (*name_ != *(d.name_));
}
} // namespace ge } // namespace ge

@ -384,12 +384,15 @@ void FormatRefiner::RefreshOriginFormatOfAnchor(std::vector<ge::NodePtr> &anchor
continue; continue;
} }
for (const auto &input_desc : node->GetOpDesc()->GetAllInputsDescPtr()) { for (const auto &input_desc : node->GetOpDesc()->GetAllInputsDescPtr()) {
if (input_desc != nullptr) { // single op support private format set, its origin format should not be override
auto ori_format = input_desc->GetOriginFormat();
if (input_desc != nullptr && (ori_format == FORMAT_ND || ori_format == FORMAT_RESERVED)) {
input_desc->SetOriginFormat(input_desc->GetFormat()); input_desc->SetOriginFormat(input_desc->GetFormat());
} }
} }
for (const auto &output_desc : node->GetOpDesc()->GetAllOutputsDescPtr()) { for (const auto &output_desc : node->GetOpDesc()->GetAllOutputsDescPtr()) {
if (output_desc != nullptr) { auto ori_format = output_desc->GetOriginFormat();
if (output_desc != nullptr && (ori_format == FORMAT_ND || ori_format == FORMAT_RESERVED)) {
output_desc->SetOriginFormat(output_desc->GetFormat()); output_desc->SetOriginFormat(output_desc->GetFormat());
} }
} }

@ -1078,6 +1078,9 @@ const std::string ATTR_NAME_DYNAMIC_INPUT_END = "_dynamic_input_index_end";
const std::string ATTR_ATC_USER_DEFINE_DATATYPE = "_user_defined_data_type"; const std::string ATTR_ATC_USER_DEFINE_DATATYPE = "_user_defined_data_type";
const std::string ATTR_ATC_USER_DEFINE_FORMAT = "_user_defined_format"; const std::string ATTR_ATC_USER_DEFINE_FORMAT = "_user_defined_format";
// atc user def dtype&format
const std::string ATTR_ATC_USER_DEFINE_OUTPUT_NODES = "_user_defined_output_nodes";
// for fusion op plugin // for fusion op plugin
const std::string ATTR_NAME_FUSIONOP_ORIGINAL_TYPE = "_fusionop_original_type"; const std::string ATTR_NAME_FUSIONOP_ORIGINAL_TYPE = "_fusionop_original_type";

@ -46,6 +46,10 @@ COMMON_LOCAL_SRC_FILES := \
option/ge_local_context.cc \ option/ge_local_context.cc \
./runtime_inference_context.cc \ ./runtime_inference_context.cc \
./utils/node_utils.cc \ ./utils/node_utils.cc \
../third_party/transformer/src/axis_util.cpp \
../third_party/transformer/src/transfer_shape_according_to_format.cpp \
./utils/transformer_utils.cc \
COMMON_LOCAL_C_INCLUDES := \ COMMON_LOCAL_C_INCLUDES := \
proto/om.proto \ proto/om.proto \
@ -57,13 +61,19 @@ COMMON_LOCAL_C_INCLUDES := \
proto/op_mapping_info.proto \ proto/op_mapping_info.proto \
proto/dump_task.proto \ proto/dump_task.proto \
inc \ inc \
metadef/inc \
graphengine/inc \
inc/external \ inc/external \
inc/external/graph \ metadef/inc/external \
inc/graph \ graphengine/inc/external \
inc/common \ metadef/inc/external/graph \
common \ metadef/inc/graph \
common/graph \ metadef/inc/common \
metadef \
metadef/graph \
third_party/protobuf/include \ third_party/protobuf/include \
$(TOPDIR)metadef/third_party \
$(TOPDIR)metadef/third_party/transformer/inc \
libc_sec/include \ libc_sec/include \
ops/built-in/op_proto/inc \ ops/built-in/op_proto/inc \
cann/ops/built-in/op_proto/inc \ cann/ops/built-in/op_proto/inc \

@ -27,6 +27,7 @@
#include "graph/utils/attr_utils.h" #include "graph/utils/attr_utils.h"
#include "graph/utils/ge_ir_utils.h" #include "graph/utils/ge_ir_utils.h"
#include "graph/utils/op_desc_utils.h" #include "graph/utils/op_desc_utils.h"
#include "graph/utils/transformer_utils.h"
#include "proto/ge_ir.pb.h" #include "proto/ge_ir.pb.h"
using std::make_pair; using std::make_pair;
@ -1301,11 +1302,24 @@ graphStatus OpDesc::CallInferFunc(Operator &op) {
return GRAPH_PARAM_INVALID; return GRAPH_PARAM_INVALID;
} }
} }
std::unique_ptr<NodeShapeTransUtils> transformer(new (std::nothrow) NodeShapeTransUtils(shared_from_this()));
if (transformer == nullptr) {
GELOGE(GRAPH_FAILED, "Memory alloc failed");
return GRAPH_FAILED;
}
if (!transformer->CatchFormatAndShape()) {
GELOGE(GRAPH_FAILED, "catch format and shape info failed!");
return GRAPH_FAILED;
}
graphStatus graph_status = (graphStatus)infer_func_(op); graphStatus graph_status = (graphStatus)infer_func_(op);
if (graph_status != GRAPH_SUCCESS) { if (graph_status != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "%s call infer func. ret: %u", GetName().c_str(), graph_status); GELOGE(GRAPH_FAILED, "%s call infer func. ret: %u", GetName().c_str(), graph_status);
return GRAPH_FAILED; return GRAPH_FAILED;
} }
if (!transformer->UpdateFormatAndShape()) {
GELOGE(GRAPH_FAILED, "catch format and shape info failed!");
return GRAPH_FAILED;
}
return GRAPH_SUCCESS; return GRAPH_SUCCESS;
} }
graphStatus OpDesc::CallInferFormatFunc(Operator &op) { graphStatus OpDesc::CallInferFormatFunc(Operator &op) {

@ -1425,7 +1425,10 @@ class GraphBuilderImpl {
const string name = node->GetName(); const string name = node->GetName();
for (auto &name_idx : op_impl->op_desc_->GetSubgraphNameIndexes()) { for (auto &name_idx : op_impl->op_desc_->GetSubgraphNameIndexes()) {
const SubgraphBuilder &builder = op_impl->GetSubgraphBuilder(name_idx.first); const SubgraphBuilder &builder = op_impl->GetSubgraphBuilder(name_idx.first);
GE_CHK_BOOL_EXEC(builder != nullptr, return GRAPH_FAILED, "Node: %s, Get builder failed.", name.c_str()); if (builder == nullptr) {
GELOGW("Node: %s, Has no builder.", name.c_str());
continue;
}
Graph graph = builder(); // Build subgraph from user define builder. Graph graph = builder(); // Build subgraph from user define builder.
const ComputeGraphPtr &subgraph = GraphUtils::GetComputeGraph(graph); const ComputeGraphPtr &subgraph = GraphUtils::GetComputeGraph(graph);

@ -26,6 +26,7 @@
#include "debug/ge_log.h" #include "debug/ge_log.h"
#include "debug/ge_op_types.h" #include "debug/ge_op_types.h"
#include "debug/ge_util.h"
#include "external/graph/operator.h" #include "external/graph/operator.h"
#include "external/graph/operator_factory.h" #include "external/graph/operator_factory.h"
#include "framework/common/debug/ge_log.h" #include "framework/common/debug/ge_log.h"
@ -41,7 +42,6 @@ const uint32_t kWhileBodySubGraphIdx = 1;
graphStatus ReverseBrushWhileBodySubGraph(const ConstNodePtr &node) { graphStatus ReverseBrushWhileBodySubGraph(const ConstNodePtr &node) {
GELOGD("Enter reverse brush while body subgraph process!"); GELOGD("Enter reverse brush while body subgraph process!");
auto sub_graph_body = NodeUtils::GetSubgraph(*node, kWhileBodySubGraphIdx); auto sub_graph_body = NodeUtils::GetSubgraph(*node, kWhileBodySubGraphIdx);
if (sub_graph_body == nullptr) { if (sub_graph_body == nullptr) {
GELOGE(GRAPH_FAILED, "Get while body graph failed!"); GELOGE(GRAPH_FAILED, "Get while body graph failed!");
@ -661,10 +661,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ShapeRefiner::InferSh
if (!is_unknown_graph) { if (!is_unknown_graph) {
auto inference_context = CreateInferenceContext(context_map, node); auto inference_context = CreateInferenceContext(context_map, node);
if (inference_context == nullptr) { GE_CHECK_NOTNULL(inference_context);
GELOGE(GRAPH_FAILED, "inference context is null");
return GRAPH_FAILED;
}
GELOGD("create context for node:%s, marks %zu", node->GetName().c_str(), inference_context->GetMarks().size()); GELOGD("create context for node:%s, marks %zu", node->GetName().c_str(), inference_context->GetMarks().size());
op.SetInferenceContext(inference_context); op.SetInferenceContext(inference_context);
} }
@ -678,8 +675,11 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ShapeRefiner::InferSh
auto op_desc = node->GetOpDesc(); auto op_desc = node->GetOpDesc();
for (const auto &out_anchor : node->GetAllOutDataAnchors()) { for (const auto &out_anchor : node->GetAllOutDataAnchors()) {
auto output_tensor = op_desc->MutableOutputDesc(out_anchor->GetIdx()); auto output_tensor = op_desc->MutableOutputDesc(out_anchor->GetIdx());
ge::TensorUtils::SetRealDimCnt(*output_tensor, static_cast<uint32_t>(output_tensor->GetShape().GetDims().size())); if (output_tensor->MutableShape().GetDims().empty()) {
output_tensor->SetOriginShape(output_tensor->GetShape()); output_tensor->SetOriginShape(output_tensor->GetShape());
}
ge::TensorUtils::SetRealDimCnt(*output_tensor,
static_cast<uint32_t>(output_tensor->GetOriginShape().GetDims().size()));
output_tensor->SetOriginDataType(output_tensor->GetDataType()); output_tensor->SetOriginDataType(output_tensor->GetDataType());
GELOGD("node name is %s, origin shape is %ld, origin format is %s, origin data type is %s", GELOGD("node name is %s, origin shape is %ld, origin format is %s, origin data type is %s",

@ -0,0 +1,144 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*!
* \file axis_util.h
* \brief get the axis value
*/
#ifndef COMMON_UTILS_TRANSFER_AXIS_UTIL_H_
#define COMMON_UTILS_TRANSFER_AXIS_UTIL_H_
#include <memory.h>
#include <functional>
#include <vector>
#include "external/graph/ge_error_codes.h"
#include "external/graph/types.h"
#include "framework/common/debug/ge_log.h"
namespace common {
namespace transformer {
const int32_t DIM_DEFAULT_SIZE = 4;
const uint32_t NCHW_DIMENSION_NUM = 4;
const int32_t AXIS_NCHW_DIM_N = 0;
const int32_t AXIS_NCHW_DIM_C = 1;
const int32_t AXIS_NCHW_DIM_H = 2;
const int32_t AXIS_NCHW_DIM_W = 3;
const int32_t AXIS_NHWC_DIM_N = 0;
const int32_t AXIS_NHWC_DIM_H = 1;
const int32_t AXIS_NHWC_DIM_W = 2;
const int32_t AXIS_NHWC_DIM_C = 3;
const int32_t AXIS_NC1HWC0_DIM_N = 0;
const int32_t AXIS_NC1HWC0_DIM_C1 = 1;
const int32_t AXIS_NC1HWC0_DIM_C0 = 4;
const int32_t AXIS_NC1HWC0_DIM_H = 2;
const int32_t AXIS_NC1HWC0_DIM_W = 3;
const int32_t AXIS_HWCN_DIM_H = 0;
const int32_t AXIS_HWCN_DIM_W = 1;
const int32_t AXIS_HWCN_DIM_C = 2;
const int32_t AXIS_HWCN_DIM_N = 3;
const int32_t AXIS_C1HWNCoC0_DIM_C1 = 0;
const int32_t AXIS_C1HWNCoC0_DIM_H = 1;
const int32_t AXIS_C1HWNCoC0_DIM_W = 2;
const int32_t AXIS_C1HWNCoC0_DIM_N = 3;
const int32_t AXIS_C1HWNCoC0_DIM_Co = 4;
const int32_t AXIS_C1HWNCoC0_DIM_C0 = 5;
#define CHECK_NOTNULL(val) \
do { \
if ((val) == nullptr) { \
GELOGE(GRAPH_FAILED, "[ERROR]Parameter[%s] must not be null.", #val); \
return false; \
} \
} while (0)
#define CHECK(cond, log_func, return_expr) \
do { \
if (cond) { \
log_func; \
return_expr; \
} \
} while (0)
enum AxisValueType {
AXIS_N = 0,
AXIS_C = 1,
AXIS_H = 2,
AXIS_W = 3,
AXIS_C1 = 4,
AXIS_C0 = 5,
AXIS_Co = 6,
AXIS_D = 7,
AXIS_BOTTOM = 8
};
int64_t DivisionCeiling(int64_t dividend, int64_t divisor);
/* Axis value is arranged as {N,C,H,W,C1,C0,...} */
/* The first parameter is old shape's dimension,
* second is c0 and third is axis value. */
using GetAxisValueInfoByFormat =
std::function<bool(const std::vector<int64_t>&, const uint32_t&, std::vector<int64_t>&, std::vector<int64_t>&)>;
using GetAxisValueInfoByFormatPtr = std::shared_ptr<GetAxisValueInfoByFormat>;
class AxisUtil {
public:
AxisUtil();
~AxisUtil(){};
bool GetAxisValueByOriginFormat(const ge::Format& format, const std::vector<int64_t>& dimVec, const uint32_t& c0,
std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue);
bool HasAxisValueFunc(const ge::Format& format);
private:
static bool CheckParams(const std::vector<int64_t>& originalDimVec, const uint32_t& c0,
std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue);
static bool GetAxisValueByNCHW(const std::vector<int64_t>& originalDimVec, const uint32_t& c0,
std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue);
static bool GetAxisValueByNHWC(const std::vector<int64_t>& originalDimVec, const uint32_t& c0,
std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue);
static bool GetAxisValueByNC1HWC0(const std::vector<int64_t>& originalDimVec, const uint32_t& c0,
std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue);
static bool GetAxisValueByFz(const std::vector<int64_t>& originalDimVec, const uint32_t& c0,
std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue);
static bool GetAxisValueByHWCN(const std::vector<int64_t>& originalDimVec, const uint32_t& c0,
std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue);
static bool GetAxisValueByND(const std::vector<int64_t>& originalDimVec, const uint32_t& c0,
std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue);
static bool GetAxisValueByC1HWNCoC0(const std::vector<int64_t>& originalDimVec, const uint32_t& c0,
std::vector<int64_t>& axisValue, std::vector<int64_t>& ndValue);
/* map of GetAxisValueInfoByFormat, get axis value by different original
* formats. */
std::map<ge::Format, GetAxisValueInfoByFormatPtr> getAxisValueFuncMap;
};
} // namespace transformer
} // namespace common
#endif // COMMON_UTILS_TRANSFER_AXIS_UTIL_H_

@ -0,0 +1,122 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*!
* \file transfer_shape_according_to_format.h
* \brief set shape according to original format and current format
*/
#ifndef COMMON_UTILS_TRANSFER_SHAPE_ACCORDING_TO_FORMAT_H_
#define COMMON_UTILS_TRANSFER_SHAPE_ACCORDING_TO_FORMAT_H_
#include "transformer/inc/axis_util.h"
#include <memory.h>
#include <functional>
#include <vector>
#include "graph/types.h"
#include "graph/utils/op_desc_utils.h"
namespace common {
namespace transformer {
enum OpImplType {
EN_IMPL_CUSTOM_CONSTANT_CCE = 0, // custom constant op
EN_IMPL_CUSTOM_TIK, // custom tik op
EN_IMPL_CUSTOM_TBE, // custom tbe op
EN_IMPL_HW_CONSTANT_CCE, // Huawei built-in constant op
EN_IMPL_HW_GENERAL_CCE, // Huawei built-in cce op
EN_IMPL_HW_TIK, // Huawei built-in tik op
EN_IMPL_HW_TBE, // Huawei built-in tbe op
EN_IMPL_RL, // RL op
EN_IMPL_PLUGIN_TBE, // Huawei built-in tbe plugin op
EN_IMPL_VECTOR_CORE_HW_TBE, // Huawei built-in tbe op
EN_IMPL_VECTOR_CORE_CUSTOM_TBE, // custom tbe op
EN_IMPL_NON_PERSISTENT_CUSTOM_TBE, // custom tbe op
EN_RESERVED // reserved value
};
const uint32_t SHAPE_NUMBER_16 = 16;
const uint32_t SHAPE_NUMBER_32 = 32;
const uint32_t SHAPE_DIM_VALUE_C04 = 4;
const uint32_t NI = 16;
const uint32_t MINUS_VALUE_ONE = 1;
const uint32_t MINUS_VALUE_TWO = 2;
const uint32_t SIZE_OF_CN = 2;
const uint32_t MINIMUM_NZ_SHAPE_DIM_NUM = 2;
/* The first parameter is axis value, second is new shape and third is
* op implementation type. */
using GetNewShapeByAxisValueAndFormat =
std::function<bool(vector<int64_t> &, const int64_t &, vector<int64_t> &, vector<int64_t> &)>;
using GetNewShapeByAxisValueAndFormatPtr = std::shared_ptr<GetNewShapeByAxisValueAndFormat>;
struct ShapeAndFormatInfo {
const std::vector<int64_t> &oldShape;
std::vector<int64_t> &newShape;
const ge::Format &oldFormat;
const ge::Format &newFormat;
const ge::DataType &currentDataType;
const int64_t &opImplType;
};
using ShapeAndFormat = struct ShapeAndFormatInfo;
class ShapeTransferAccordingToFormat {
public:
ShapeTransferAccordingToFormat();
~ShapeTransferAccordingToFormat(){};
ShapeTransferAccordingToFormat(const ShapeTransferAccordingToFormat &) = delete;
ShapeTransferAccordingToFormat &operator=(const ShapeTransferAccordingToFormat &) = delete;
bool GetShapeAccordingToFormat(ShapeAndFormat &inputAndOutputInfo, int64_t *c = nullptr);
/* ----------Below is the function of getting new shape---------------------- */
static bool GetNCHWShapeByAxisValue(vector<int64_t> &newShape, const int64_t &implType,
const vector<int64_t> &axisValue, const vector<int64_t> &ndValue);
static bool GetNHWCShapeByAxisValue(vector<int64_t> &newShape, const int64_t &implType,
const vector<int64_t> &axisValue, const vector<int64_t> &ndValue);
static bool GetNC1HWC0ShapeByAxisValue(vector<int64_t> &newShape, const int64_t &implType,
const vector<int64_t> &axisValue, const vector<int64_t> &ndValue);
static bool GetFzShapeByAxisValue(vector<int64_t> &newShape, const int64_t &implType,
const vector<int64_t> &axisValue, const vector<int64_t> &ndValue);
static bool GetHWCNShapeByAxisValue(vector<int64_t> &newShape, const int64_t &implType,
const vector<int64_t> &axisValue, const vector<int64_t> &ndValue);
static bool GetC1HWNCoC0ShapeByAxisValue(vector<int64_t> &newShape, const int64_t &implType,
const vector<int64_t> &axisValue, const vector<int64_t> &ndValue);
static bool GetNzShapeByAxisValue(vector<int64_t> &newShape, const int64_t &implType,
const vector<int64_t> &axisValue, const vector<int64_t> &ndValue);
private:
/* map of GetAxisValueInfoByFormat, get axis value by different original
* formats. */
std::map<ge::Format, GetNewShapeByAxisValueAndFormatPtr> getNewShapeFuncMap;
std::map<ge::DataType, uint32_t> mapOfDtypeAndC0;
};
} // namespace transformer
} // namespace common
#endif // COMMON_UTILS_TRANSFER_SHAPE_ACCORDING_TO_FORMAT_H_

@ -0,0 +1,198 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*!
* \file axis_util.cpp
* \brief get the axis value
*/
#include "transformer/inc/axis_util.h"
#include "graph/types.h"
namespace common {
namespace transformer {
using namespace ge;
using namespace std;
AxisUtil::AxisUtil() {
getAxisValueFuncMap = {{FORMAT_NCHW, std::make_shared<GetAxisValueInfoByFormat>(GetAxisValueByNCHW)},
{FORMAT_NHWC, std::make_shared<GetAxisValueInfoByFormat>(GetAxisValueByNHWC)},
{FORMAT_NC1HWC0, std::make_shared<GetAxisValueInfoByFormat>(GetAxisValueByNC1HWC0)},
{FORMAT_HWCN, std::make_shared<GetAxisValueInfoByFormat>(GetAxisValueByHWCN)},
{FORMAT_ND, std::make_shared<GetAxisValueInfoByFormat>(GetAxisValueByND)},
{FORMAT_C1HWNCoC0, std::make_shared<GetAxisValueInfoByFormat>(GetAxisValueByC1HWNCoC0)}};
}
int64_t DivisionCeiling(int64_t dividend, int64_t divisor) {
if (divisor == 0) {
return 0;
} else {
return (dividend + divisor - 1) / divisor;
}
}
bool AxisUtil::GetAxisValueByOriginFormat(const Format &format, const vector<int64_t> &dimVec, const uint32_t &c0,
vector<int64_t> &axisValue, vector<int64_t> &ndValue) {
auto iterGetAxisFunc = getAxisValueFuncMap.find(format);
if (iterGetAxisFunc == getAxisValueFuncMap.end()) {
GELOGI("Can not get axis value of old format %u!", format);
return false;
}
GetAxisValueInfoByFormatPtr getAxisFunc = iterGetAxisFunc->second;
CHECK_NOTNULL(getAxisFunc);
return (*getAxisFunc)(dimVec, c0, axisValue, ndValue);
}
bool AxisUtil::HasAxisValueFunc(const Format &format) {
auto iterGetAxisFunc = getAxisValueFuncMap.find(format);
if (iterGetAxisFunc == getAxisValueFuncMap.end()) {
GELOGI("Can not get axis value of format %u!", format);
return false;
}
return true;
}
bool AxisUtil::CheckParams(const vector<int64_t> &originalDimVec, const uint32_t &c0, vector<int64_t> &axisValue,
vector<int64_t> &ndValue) {
ndValue = originalDimVec;
auto dimSize = originalDimVec.size();
if (dimSize < DIM_DEFAULT_SIZE) {
/* Before this funcion, we should call function PadDimensionTo4. */
GELOGI("Dimension size %zu is invalid.", dimSize);
return false;
}
if (c0 == 0) {
GELOGE(GRAPH_FAILED, "[ERROR]c0 is zero!");
return false;
}
return true;
}
bool AxisUtil::GetAxisValueByND(const vector<int64_t> &originalDimVec, const uint32_t &c0, vector<int64_t> &axisValue,
vector<int64_t> &ndValue) {
CHECK(axisValue.empty(), GELOGI("AxisValue is empty!"), return true);
CHECK(originalDimVec.empty(), GELOGI("Original dim vector is empty!"), return true);
ndValue = originalDimVec;
/* To differentiate the input datatype of int8 and others */
axisValue[AXIS_C0] = c0;
if (originalDimVec.size() == NCHW_DIMENSION_NUM) {
axisValue[AXIS_N] = originalDimVec[AXIS_NCHW_DIM_N];
axisValue[AXIS_C] = originalDimVec[AXIS_NCHW_DIM_C];
axisValue[AXIS_H] = originalDimVec[AXIS_NCHW_DIM_H];
axisValue[AXIS_W] = originalDimVec[AXIS_NCHW_DIM_W];
axisValue[AXIS_C1] = DivisionCeiling(originalDimVec[AXIS_NCHW_DIM_C], (int64_t)c0);
axisValue[AXIS_Co] = c0;
}
return true;
}
bool AxisUtil::GetAxisValueByNCHW(const vector<int64_t> &originalDimVec, const uint32_t &c0, vector<int64_t> &axisValue,
vector<int64_t> &ndValue) {
CHECK(axisValue.empty(), GELOGI("AxisValue is empty!"), return true);
CHECK(originalDimVec.empty(), GELOGI("Original dim vector is empty!"), return true);
/* C0 Must be set for case ND or 2D-NCHW to NZ */
axisValue[AXIS_C0] = c0;
CHECK(CheckParams(originalDimVec, c0, axisValue, ndValue) != true, GELOGE(GRAPH_FAILED,"[ERROR]Parameter is invalid!"),
return false);
axisValue[AXIS_N] = originalDimVec[AXIS_NCHW_DIM_N];
axisValue[AXIS_C] = originalDimVec[AXIS_NCHW_DIM_C];
axisValue[AXIS_H] = originalDimVec[AXIS_NCHW_DIM_H];
axisValue[AXIS_W] = originalDimVec[AXIS_NCHW_DIM_W];
axisValue[AXIS_C1] = DivisionCeiling(originalDimVec[AXIS_NCHW_DIM_C], (int64_t)c0);
axisValue[AXIS_Co] = c0;
return true;
}
bool AxisUtil::GetAxisValueByNHWC(const vector<int64_t> &originalDimVec, const uint32_t &c0, vector<int64_t> &axisValue,
vector<int64_t> &ndValue) {
CHECK(axisValue.empty(), GELOGI("AxisValue is empty!"), return true);
CHECK(originalDimVec.empty(), GELOGI("Original dim vector is empty!"), return true);
/* C0 Must be set for case ND or 2D-NHWC to NZ */
axisValue[AXIS_C0] = c0;
CHECK(CheckParams(originalDimVec, c0, axisValue, ndValue) != true, GELOGE(GRAPH_FAILED, "[ERROR]Parameter is invalid!"),
return false);
axisValue[AXIS_N] = originalDimVec[AXIS_NHWC_DIM_N];
axisValue[AXIS_C] = originalDimVec[AXIS_NHWC_DIM_C];
axisValue[AXIS_H] = originalDimVec[AXIS_NHWC_DIM_H];
axisValue[AXIS_W] = originalDimVec[AXIS_NHWC_DIM_W];
axisValue[AXIS_C1] = DivisionCeiling(originalDimVec[AXIS_NHWC_DIM_C], (int64_t)c0);
axisValue[AXIS_Co] = c0;
return true;
}
bool AxisUtil::GetAxisValueByNC1HWC0(const vector<int64_t> &originalDimVec, const uint32_t &c0,
vector<int64_t> &axisValue, vector<int64_t> &ndValue) {
CHECK(axisValue.empty(), GELOGI("AxisValue is empty!"), return true);
CHECK(originalDimVec.empty(), GELOGI("Original dim vector is empty!"), return true);
CHECK(CheckParams(originalDimVec, c0, axisValue, ndValue) != true, GELOGE(GRAPH_FAILED,"[ERROR]Parameter is invalid!"),
return false);
auto dimSize = originalDimVec.size();
if (dimSize == DIM_DEFAULT_SIZE + 1) {
axisValue[AXIS_C1] = originalDimVec[AXIS_NC1HWC0_DIM_C1];
axisValue[AXIS_C0] = originalDimVec[AXIS_NC1HWC0_DIM_C0];
axisValue[AXIS_C] = axisValue[AXIS_C1] * axisValue[AXIS_C0];
} else {
axisValue[AXIS_C1] = DivisionCeiling(originalDimVec[AXIS_NCHW_DIM_C], (int64_t)c0);
axisValue[AXIS_C0] = c0;
axisValue[AXIS_C] = originalDimVec[AXIS_NCHW_DIM_C];
}
axisValue[AXIS_N] = originalDimVec[AXIS_NCHW_DIM_N];
axisValue[AXIS_H] = originalDimVec[AXIS_NCHW_DIM_H];
axisValue[AXIS_W] = originalDimVec[AXIS_NCHW_DIM_W];
return true;
}
bool AxisUtil::GetAxisValueByHWCN(const vector<int64_t> &originalDimVec, const uint32_t &c0, vector<int64_t> &axisValue,
vector<int64_t> &ndValue) {
CHECK(axisValue.empty(), GELOGI("AxisValue is empty!"), return true);
CHECK(originalDimVec.empty(), GELOGI("Original dim vector is empty!"), return true);
/* C0 Must be set for case ND or 2D-NHWC to NZ */
axisValue[AXIS_C0] = c0;
CHECK(CheckParams(originalDimVec, c0, axisValue, ndValue) != true, GELOGE(GRAPH_FAILED, "[ERROR]Parameter is invalid!"),
return false);
axisValue[AXIS_N] = originalDimVec[AXIS_HWCN_DIM_N];
axisValue[AXIS_C] = originalDimVec[AXIS_HWCN_DIM_C];
axisValue[AXIS_H] = originalDimVec[AXIS_HWCN_DIM_H];
axisValue[AXIS_W] = originalDimVec[AXIS_HWCN_DIM_W];
axisValue[AXIS_C1] = DivisionCeiling(originalDimVec[AXIS_HWCN_DIM_C], (int64_t)c0);
axisValue[AXIS_Co] = c0;
return true;
}
bool AxisUtil::GetAxisValueByC1HWNCoC0(const vector<int64_t> &originalDimVec, const uint32_t &c0,
vector<int64_t> &axisValue, vector<int64_t> &ndValue) {
CHECK(axisValue.empty(), GELOGI("AxisValue is empty!"), return true);
CHECK(originalDimVec.empty(), GELOGI("Original dim vector is empty!"), return true);
/* C0 Must be set for case ND or 2D-NHWC to NZ */
axisValue[AXIS_C0] = c0;
CHECK(CheckParams(originalDimVec, c0, axisValue, ndValue) != true, GELOGE(GRAPH_FAILED, "[ERROR]Parameter is invalid!"),
return false);
axisValue[AXIS_N] = originalDimVec[AXIS_C1HWNCoC0_DIM_N];
axisValue[AXIS_C] = originalDimVec[AXIS_C1HWNCoC0_DIM_C1] * c0;
axisValue[AXIS_H] = originalDimVec[AXIS_C1HWNCoC0_DIM_H];
axisValue[AXIS_W] = originalDimVec[AXIS_C1HWNCoC0_DIM_W];
axisValue[AXIS_C1] = originalDimVec[AXIS_C1HWNCoC0_DIM_C1];
axisValue[AXIS_Co] = originalDimVec[AXIS_C1HWNCoC0_DIM_Co];
return true;
}
} // namespace transformer
} // namespace common

@ -0,0 +1,242 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*!
* \file transfer_shape_according_to_format.cpp
* \brief set shape according to original format and current format
*/
#include "transformer/inc/transfer_shape_according_to_format.h"
namespace common {
namespace transformer {
using namespace ge;
using namespace std;
ShapeTransferAccordingToFormat::ShapeTransferAccordingToFormat(void) {
getNewShapeFuncMap = {
{ge::FORMAT_NCHW, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetNCHWShapeByAxisValue)},
{ge::FORMAT_NHWC, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetNHWCShapeByAxisValue)},
{ge::FORMAT_NC1HWC0, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetNC1HWC0ShapeByAxisValue)},
{ge::FORMAT_FRACTAL_Z, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetFzShapeByAxisValue)},
{ge::FORMAT_HWCN, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetHWCNShapeByAxisValue)},
{ge::FORMAT_C1HWNCoC0, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetC1HWNCoC0ShapeByAxisValue)},
{ge::FORMAT_FRACTAL_NZ, std::make_shared<GetNewShapeByAxisValueAndFormat>(GetNzShapeByAxisValue)}};
mapOfDtypeAndC0 = {
{ge::DT_FLOAT16, SHAPE_NUMBER_16}, {ge::DT_FLOAT, SHAPE_NUMBER_16}, {ge::DT_INT8, SHAPE_NUMBER_32},
{ge::DT_INT16, SHAPE_NUMBER_16}, {ge::DT_INT32, SHAPE_NUMBER_16}, {ge::DT_INT64, SHAPE_NUMBER_16},
{ge::DT_UINT8, SHAPE_NUMBER_16}, {ge::DT_UINT16, SHAPE_NUMBER_32}, {ge::DT_UINT32, SHAPE_NUMBER_16},
{ge::DT_UINT64, SHAPE_NUMBER_16}, {ge::DT_BOOL, SHAPE_NUMBER_16}};
}
bool ShapeTransferAccordingToFormat::GetNCHWShapeByAxisValue(vector<int64_t>& newShape, const int64_t& implType,
const vector<int64_t>& axisValue,
const vector<int64_t>& ndValue) {
CHECK(axisValue.empty(), GELOGD("AxisValue is empty!"), return true);
/* axisValue is initialized as a size 6 vector. */
newShape.push_back(axisValue[AXIS_N]);
newShape.push_back(axisValue[AXIS_C]);
newShape.push_back(axisValue[AXIS_H]);
newShape.push_back(axisValue[AXIS_W]);
return true;
}
bool ShapeTransferAccordingToFormat::GetNHWCShapeByAxisValue(vector<int64_t>& newShape, const int64_t& implType,
const vector<int64_t>& axisValue,
const vector<int64_t>& ndValue) {
CHECK(axisValue.empty(), GELOGD("AxisValue is empty!"), return true);
/* axisValue is initialized as a size 6 vector. */
newShape.push_back(axisValue[AXIS_N]);
newShape.push_back(axisValue[AXIS_H]);
newShape.push_back(axisValue[AXIS_W]);
newShape.push_back(axisValue[AXIS_C]);
return true;
}
bool ShapeTransferAccordingToFormat::GetNC1HWC0ShapeByAxisValue(vector<int64_t>& newShape, const int64_t& implType,
const vector<int64_t>& axisValue,
const vector<int64_t>& ndValue) {
CHECK(axisValue.empty(), GELOGD("AxisValue is empty!"), return true);
/* axisValue is initialized as a size 6 vector. */
if (implType == EN_IMPL_HW_TBE || implType == EN_IMPL_CUSTOM_TBE || implType == EN_IMPL_NON_PERSISTENT_CUSTOM_TBE) {
newShape.push_back(axisValue[AXIS_N]);
newShape.push_back(axisValue[AXIS_C1]);
newShape.push_back(axisValue[AXIS_H]);
newShape.push_back(axisValue[AXIS_W]);
newShape.push_back(axisValue[AXIS_C0]);
} else {
newShape.push_back(axisValue[AXIS_N]);
newShape.push_back(axisValue[AXIS_C]);
newShape.push_back(axisValue[AXIS_H]);
newShape.push_back(axisValue[AXIS_W]);
}
return true;
}
bool ShapeTransferAccordingToFormat::GetFzShapeByAxisValue(vector<int64_t>& newShape, const int64_t& implType,
const vector<int64_t>& axisValue,
const vector<int64_t>& ndValue) {
CHECK(axisValue.empty(), GELOGD("AxisValue is empty!"), return true);
/* axisValue is initialized as a size 6 vector. */
if (ndValue.size() == SIZE_OF_CN) {
auto sizeOfOriginalVec = ndValue.size();
newShape = ndValue;
/* sizeOfOriginalVec - 1 mean the last value of original vec
* sizeOfOriginalVec - 2 mean the second last value of original vec */
newShape[sizeOfOriginalVec - MINUS_VALUE_ONE] =
DivisionCeiling(ndValue[sizeOfOriginalVec - MINUS_VALUE_ONE], SHAPE_NUMBER_16);
newShape[sizeOfOriginalVec - MINUS_VALUE_TWO] =
DivisionCeiling(ndValue[sizeOfOriginalVec - MINUS_VALUE_TWO], axisValue[AXIS_C0]);
newShape.push_back(SHAPE_NUMBER_16);
newShape.push_back(axisValue[AXIS_C0]);
} else {
if (implType == EN_IMPL_HW_TBE || implType == EN_IMPL_CUSTOM_TBE || implType == EN_IMPL_NON_PERSISTENT_CUSTOM_TBE) {
int64_t hwc1 = axisValue[AXIS_C1] * axisValue[AXIS_H] * axisValue[AXIS_W];
newShape.push_back(hwc1);
newShape.push_back(DivisionCeiling(axisValue[AXIS_N], NI));
newShape.push_back(NI);
newShape.push_back(axisValue[AXIS_C0]);
} else {
newShape.push_back(axisValue[AXIS_N]);
newShape.push_back(axisValue[AXIS_C]);
newShape.push_back(axisValue[AXIS_H]);
newShape.push_back(axisValue[AXIS_W]);
}
}
return true;
}
bool ShapeTransferAccordingToFormat::GetHWCNShapeByAxisValue(vector<int64_t>& newShape, const int64_t& implType,
const vector<int64_t>& axisValue,
const vector<int64_t>& ndValue) {
CHECK(axisValue.empty(), GELOGD("AxisValue is empty!"), return true);
/* axisValue is initialized as a size 6 vector. */
newShape.push_back(axisValue[AXIS_H]);
newShape.push_back(axisValue[AXIS_W]);
newShape.push_back(axisValue[AXIS_C]);
newShape.push_back(axisValue[AXIS_N]);
return true;
}
bool ShapeTransferAccordingToFormat::GetC1HWNCoC0ShapeByAxisValue(vector<int64_t>& newShape, const int64_t& implType,
const vector<int64_t>& axisValue,
const vector<int64_t>& ndValue) {
CHECK(axisValue.empty(), GELOGD("AxisValue is empty!"), return true);
/* axisValue is initialized as a size 6 vector. */
newShape.push_back(axisValue[AXIS_C1]);
newShape.push_back(axisValue[AXIS_H]);
newShape.push_back(axisValue[AXIS_W]);
newShape.push_back(axisValue[AXIS_N]);
newShape.push_back(axisValue[AXIS_Co]);
newShape.push_back(axisValue[AXIS_C0]);
return true;
}
bool ShapeTransferAccordingToFormat::GetNzShapeByAxisValue(vector<int64_t>& newShape, const int64_t& implType,
const vector<int64_t>& axisValue,
const vector<int64_t>& ndValue) {
CHECK(ndValue.empty(), GELOGD("ndValue is empty!"), return true);
CHECK(axisValue.empty() || axisValue.size() <= AXIS_C0,
GELOGD("AxisValue is empty or its size %zu <= AXIS_C0[%u]", axisValue.size(), AXIS_C0), return true);
uint32_t sizeOfOriginalVec = ndValue.size();
if (sizeOfOriginalVec < MINIMUM_NZ_SHAPE_DIM_NUM) {
GELOGD("ndValue's dim num is less than 2!");
return true;
}
/* axisValue is initialized as a size 6 vector. */
newShape = ndValue;
/* sizeOfOriginalVec - 1 mean the last value of original vec
* sizeOfOriginalVec - 2 mean the second last value of original vec */
newShape[sizeOfOriginalVec - MINUS_VALUE_ONE] =
DivisionCeiling(ndValue[sizeOfOriginalVec - MINUS_VALUE_TWO], (int64_t)SHAPE_NUMBER_16);
newShape[sizeOfOriginalVec - MINUS_VALUE_TWO] =
DivisionCeiling(ndValue[sizeOfOriginalVec - MINUS_VALUE_ONE], axisValue[AXIS_C0]);
newShape.push_back(SHAPE_NUMBER_16);
newShape.push_back(axisValue[AXIS_C0]);
return true;
}
bool ShapeTransferAccordingToFormat::GetShapeAccordingToFormat(ShapeAndFormat& shapeAndFormatInfo, int64_t* c) {
/* The default new shape is old shape */
shapeAndFormatInfo.newShape = shapeAndFormatInfo.oldShape;
if (shapeAndFormatInfo.oldFormat >= ge::FORMAT_RESERVED || shapeAndFormatInfo.newFormat >= ge::FORMAT_RESERVED) {
GELOGE(GRAPH_FAILED, "Old format %u or new format %u is invalid!", shapeAndFormatInfo.oldFormat,
shapeAndFormatInfo.newFormat);
return false;
}
if (shapeAndFormatInfo.currentDataType >= ge::DT_UNDEFINED) {
GELOGE(GRAPH_FAILED, "currentDataType %u is invalid!", shapeAndFormatInfo.currentDataType);
return false;
}
AxisUtil* axisutil_object = new AxisUtil();
if (!axisutil_object->HasAxisValueFunc(shapeAndFormatInfo.oldFormat)) {
delete axisutil_object;
return true;
}
auto iterGetNewShapeFunc = getNewShapeFuncMap.find(shapeAndFormatInfo.newFormat);
if (iterGetNewShapeFunc == getNewShapeFuncMap.end()) {
GELOGD("Can not get new shape of new format %u!", shapeAndFormatInfo.newFormat);
delete axisutil_object;
return true;
}
GELOGD("Original format %u, new format %u", shapeAndFormatInfo.oldFormat, shapeAndFormatInfo.newFormat);
GetNewShapeByAxisValueAndFormatPtr getNewShapeFunc = iterGetNewShapeFunc->second;
CHECK_NOTNULL(getNewShapeFunc);
std::vector<int64_t> axisValue;
for (uint32_t i = 0; i < AXIS_BOTTOM; i++) {
axisValue.push_back(1);
}
std::vector<int64_t> ndValue;
uint32_t c0;
if (mapOfDtypeAndC0.empty()) {
c0 = SHAPE_NUMBER_16;
} else {
auto iterGetC0 = mapOfDtypeAndC0.find(shapeAndFormatInfo.currentDataType);
if (iterGetC0 == mapOfDtypeAndC0.end()) {
GELOGE(GRAPH_FAILED, "Dtype is not support.");
delete axisutil_object;
return true;
}
c0 = iterGetC0->second;
}
// The value of C0 should be 4 while format is 5HD-4 or FRAZ-4
if (shapeAndFormatInfo.newFormat == ge::FORMAT_NC1HWC0_C04) {
c0 = SHAPE_DIM_VALUE_C04;
}
bool status = axisutil_object->GetAxisValueByOriginFormat(
shapeAndFormatInfo.oldFormat, shapeAndFormatInfo.oldShape, c0, axisValue, ndValue);
if (status != true && shapeAndFormatInfo.newFormat != ge::FORMAT_FRACTAL_NZ) {
delete axisutil_object;
return true;
}
delete axisutil_object;
shapeAndFormatInfo.newShape.clear();
(*getNewShapeFunc)(shapeAndFormatInfo.newShape, shapeAndFormatInfo.opImplType, axisValue, ndValue);
if (c != nullptr) {
*c = axisValue[AXIS_C];
}
return true;
}
} // namespace transformer
} // namespace common

@ -0,0 +1,160 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "transformer_utils.h"
#include "external/ge/ge_api_types.h"
#include "framework/common/debug/ge_log.h"
#include "graph/utils/type_utils.h"
namespace ge {
bool NodeShapeTransUtils::CatchFormatAndShape() {
inputs_ = op_desc_->GetAllInputName();
outputs_ = op_desc_->GetAllOutputName();
for (auto &ele : inputs_) {
auto tensor_desc_input = op_desc_->MutableInputDesc(ele.first);
if (tensor_desc_input == nullptr) {
continue;
}
auto format = tensor_desc_input->GetFormat();
auto ori_format = tensor_desc_input->GetOriginFormat();
if (format == ori_format) {
GELOGD("Node is %s, input tensor name is %s. ori format: %s, format: %s is same! No need to catch format&shape!",
op_desc_->GetName().c_str(), ele.first.c_str(), TypeUtils::FormatToSerialString(ori_format).c_str(),
TypeUtils::FormatToSerialString(format).c_str());
continue;
}
map_format_in_.insert(std::pair<std::string, Format>(ele.first, format));
map_ori_format_in_.insert(std::pair<std::string, Format>(ele.first, ori_format));
map_dtype_in_.insert(std::pair<std::string, DataType>(ele.first, tensor_desc_input->GetDataType()));
tensor_desc_input->SetFormat(ori_format);
tensor_desc_input->SetShape(tensor_desc_input->GetOriginShape());
}
for (auto &ele : outputs_) {
auto tensor_desc_output = op_desc_->MutableOutputDesc(ele.first);
if (tensor_desc_output == nullptr) {
continue;
}
auto format = tensor_desc_output->GetFormat();
auto ori_format = tensor_desc_output->GetOriginFormat();
if (format == ori_format) {
GELOGD("Node is %s, output tensor name is %s. ori format: %s, format: %s is same! No need to catch format&shape!",
op_desc_->GetName().c_str(), ele.first.c_str(), TypeUtils::FormatToSerialString(ori_format).c_str(),
TypeUtils::FormatToSerialString(format).c_str());
continue;
}
map_format_out_.insert(std::pair<std::string, Format>(ele.first, format));
map_ori_format_out_.insert(std::pair<std::string, Format>(ele.first, ori_format));
map_dtype_out_.insert(std::pair<std::string, DataType>(ele.first, tensor_desc_output->GetDataType()));
if (format == ori_format) {
continue;
}
tensor_desc_output->SetFormat(ori_format);
}
return true;
}
bool NodeShapeTransUtils::UpdateFormatAndShape() {
for (auto &ele : inputs_) {
auto tensor_desc_input = op_desc_->MutableInputDesc(ele.first);
if (tensor_desc_input == nullptr) {
continue;
}
// if can not find saved info, it says format and origin format is same when catched
if (map_format_in_.find(ele.first) == map_format_in_.end()) {
GELOGD("Node is [%s], input tensor name [%s] is not been catched.Skip update action for it!",
op_desc_->GetName().c_str(), ele.first.c_str());
tensor_desc_input->SetOriginFormat(tensor_desc_input->GetFormat());
tensor_desc_input->SetOriginShape(tensor_desc_input->GetShape());
continue;
}
auto ori_format = tensor_desc_input->GetFormat();
auto ori_shape = tensor_desc_input->GetShape();
auto curr_format = map_format_in_[ele.first];
if (ori_format == curr_format) {
continue;
}
std::unique_ptr<common::transformer::ShapeTransferAccordingToFormat> shape_transfer(
new (std::nothrow) common::transformer::ShapeTransferAccordingToFormat());
if (shape_transfer == nullptr) {
GELOGE(GRAPH_FAILED, "Memory alloc failed");
return false;
}
std::vector<int64_t> ori_shape_dims = ori_shape.GetDims();
std::vector<int64_t> out_dims;
ge::DataType dtype = map_dtype_in_[ele.first];
common::transformer::ShapeAndFormat shape_and_format_info{
ori_shape_dims, out_dims, ori_format, curr_format, dtype, common::transformer::EN_IMPL_CUSTOM_TBE};
shape_transfer->GetShapeAccordingToFormat(shape_and_format_info);
tensor_desc_input->SetFormat(curr_format);
tensor_desc_input->SetShape(GeShape(out_dims));
}
for (auto &ele : outputs_) {
auto tensor_desc_output = op_desc_->MutableOutputDesc(ele.first);
if (tensor_desc_output == nullptr) {
continue;
}
// if can not find saved info, it says format and origin format is same when catched
if (map_ori_format_out_.find(ele.first) == map_ori_format_out_.end()) {
GELOGD("Node is [%s], input tensor name [%s] is not been catched.Skip update action for it!",
op_desc_->GetName().c_str(), ele.first.c_str());
tensor_desc_output->SetOriginFormat(tensor_desc_output->GetFormat());
tensor_desc_output->SetOriginShape(tensor_desc_output->GetShape());
continue;
}
auto ori_shape = tensor_desc_output->GetShape();
auto curr_format = tensor_desc_output->GetFormat();
if (curr_format != map_ori_format_out_[ele.first]) {
GELOGE(GRAPH_FAILED, "Node is %s, out tensor name is %s. format: %s, recorded origin format: %s is not same",
op_desc_->GetName().c_str(), ele.first.c_str(), TypeUtils::FormatToSerialString(curr_format).c_str(),
TypeUtils::FormatToSerialString(map_ori_format_out_[ele.first]).c_str());
return GRAPH_FAILED;
}
tensor_desc_output->SetOriginShape(ori_shape);
auto saved_format = map_format_out_[ele.first];
if (curr_format == saved_format) {
GELOGD("Nodeis %s, out tensor name is %s. ori format: %s, recorded format: %s is same! No need to transfer",
op_desc_->GetName().c_str(), ele.first.c_str(), TypeUtils::FormatToSerialString(curr_format).c_str(),
TypeUtils::FormatToSerialString(saved_format).c_str());
continue;
}
tensor_desc_output->SetFormat(saved_format);
std::unique_ptr<common::transformer::ShapeTransferAccordingToFormat> shape_transfer(
new (std::nothrow) common::transformer::ShapeTransferAccordingToFormat());
if (shape_transfer == nullptr) {
GELOGE(GRAPH_FAILED, "Memory alloc failed");
return false;
}
std::vector<int64_t> ori_shape_dims = ori_shape.GetDims();
std::vector<int64_t> out_dims;
ge::DataType dtype = tensor_desc_output->GetDataType();
common::transformer::ShapeAndFormat shape_and_format_info{
ori_shape_dims, out_dims, curr_format, saved_format, dtype, common::transformer::EN_IMPL_CUSTOM_TBE};
shape_transfer->GetShapeAccordingToFormat(shape_and_format_info);
tensor_desc_output->SetShape(GeShape(out_dims));
GELOGD("Node is %s, out tensor name is %s. Update format and shape successori format: %s, format: %s",
op_desc_->GetName().c_str(), ele.first.c_str(), TypeUtils::FormatToSerialString(curr_format).c_str(),
TypeUtils::FormatToSerialString(saved_format).c_str());
}
GELOGD("Node is %s. Update format and shape success", op_desc_->GetName().c_str());
return true;
}
} // namespace ge

@ -0,0 +1,50 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef COMMON_GRAPH_UTILS_TRANSFORMER_UTILS_H_
#define COMMON_GRAPH_UTILS_TRANSFORMER_UTILS_H_
#include <string>
#include <map>
#include "external/graph/types.h"
#include "graph/op_desc.h"
#include "graph/ge_tensor.h"
#include "transformer/inc/transfer_shape_according_to_format.h"
namespace ge {
class NodeShapeTransUtils {
public:
bool CatchFormatAndShape();
bool UpdateFormatAndShape();
explicit NodeShapeTransUtils(OpDescPtr op_desc) : op_desc_(op_desc) {}
~NodeShapeTransUtils() {}
private:
std::map<std::string, Format> map_format_in_;
std::map<std::string, Format> map_ori_format_in_;
std::map<std::string, DataType> map_dtype_in_;
std::map<std::string, Format> map_format_out_;
std::map<std::string, Format> map_ori_format_out_;
std::map<std::string, DataType> map_dtype_out_;
std::map<std::string, uint32_t> inputs_;
std::map<std::string, uint32_t> outputs_;
OpDescPtr op_desc_;
};
} // namespace ge
#endif // COMMON_GRAPH_UTILS_TRANSFORMER_UTILS_H_

@ -260,6 +260,33 @@ Status Session::AddGraph(uint32_t graph_id, const Graph &graph, const std::map<s
return ret; return ret;
} }
Status Session::AddGraphWithCopy(uint32_t graph_id, const Graph &graph) {
std::map<AscendString, AscendString> options;
return AddGraphWithCopy(graph_id, graph, options);
}
Status Session::AddGraphWithCopy(uint32_t graph_id, const Graph &graph,
const std::map<AscendString, AscendString> &options) {
GELOGT(TRACE_INIT, "Start to add graph in Session. graph_id: %u, session_id: %lu.", graph_id, sessionId_);
std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance();
if (instance_ptr == nullptr || !instance_ptr->InitFlag()) {
GELOGE(GE_CLI_GE_NOT_INITIALIZED, "AddGraph failed in Session.");
return FAILED;
}
std::map<std::string, std::string> str_options;
for (auto it = options.begin(); it != options.end(); ++it) {
str_options.insert({it->first.GetString(), it->second.GetString()});
}
GELOGD("Adding graph to session");
Status ret = instance_ptr->SessionManagerObj().AddGraphWithCopy(sessionId_, graph_id, graph, str_options);
if (ret != SUCCESS) {
GELOGE(ret, "AddGraph failed in Session.");
return FAILED;
}
GELOGD("AddGraph finished in Session.");
return ret;
}
Status Session::RemoveGraph(uint32_t graph_id) { Status Session::RemoveGraph(uint32_t graph_id) {
GELOGT(TRACE_INIT, "Session RemoveGraph start"); GELOGT(TRACE_INIT, "Session RemoveGraph start");

@ -24,6 +24,7 @@
#include "common/fp16_t.h" #include "common/fp16_t.h"
#include "common/ge/ge_util.h" #include "common/ge/ge_util.h"
#include "framework/common/debug/ge_log.h" #include "framework/common/debug/ge_log.h"
#include "framework/common/debug/log.h"
#include "graph/utils/type_utils.h" #include "graph/utils/type_utils.h"
#include "securec.h" #include "securec.h"
@ -123,21 +124,25 @@ Status DataTypeTransfer::TransDataType(const CastArgs &args, TransResult &result
std::pair<DataType, DataType> trans_info(args.src_data_type, args.dst_data_type); std::pair<DataType, DataType> trans_info(args.src_data_type, args.dst_data_type);
auto iter = trans_mode_map.find(trans_info); auto iter = trans_mode_map.find(trans_info);
if (iter == trans_mode_map.end()) { if (iter == trans_mode_map.end()) {
GELOGE(PARAM_INVALID, "Trans data type from %s to %s is not supported.", std::string error = "Failed to trans data from datatype " +
TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), FmtToStr(TypeUtils::DataTypeToSerialString(args.src_data_type)) + " to " +
TypeUtils::DataTypeToSerialString(args.dst_data_type).c_str()); FmtToStr(TypeUtils::DataTypeToSerialString(args.dst_data_type)) + " , it is not supported.";
GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str());
return UNSUPPORTED; return UNSUPPORTED;
} }
auto trans_mode = iter->second; auto trans_mode = iter->second;
int size = GetSizeByDataType(args.dst_data_type); int size = GetSizeByDataType(args.dst_data_type);
if (size <= 0) { if (size <= 0) {
GELOGE(PARAM_INVALID, "Failed to calc size from data type %s", std::string error = "Failed to calc size from data type" +
TypeUtils::DataTypeToSerialString(args.dst_data_type).c_str()); FmtToStr(TypeUtils::DataTypeToSerialString(args.dst_data_type)) + ", it is not supported.";
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str());
return PARAM_INVALID; return PARAM_INVALID;
} }
if (args.src_data_size > static_cast<size_t>(SIZE_MAX / size)) { if (args.src_data_size > static_cast<size_t>(SIZE_MAX / size)) {
GELOGE(PARAM_INVALID, "args.src_data_size %zu or data type size %d too big.", args.src_data_size, size); std::string error =
"args.src_data_size" + FmtToStr(args.src_data_size) + " or data type size" + FmtToStr(size) + " is too big";
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str());
return PARAM_INVALID; return PARAM_INVALID;
} }
size_t total_size = static_cast<size_t>(args.src_data_size * size); size_t total_size = static_cast<size_t>(args.src_data_size * size);
@ -154,9 +159,11 @@ Status DataTypeTransfer::TransDataType(const CastArgs &args, TransResult &result
} }
if (CastKernel(args, dst.get(), args.src_data_size, trans_mode) != SUCCESS) { if (CastKernel(args, dst.get(), args.src_data_size, trans_mode) != SUCCESS) {
GELOGE(INTERNAL_ERROR, "Failed to cast data from %s to %s, data size %zu", std::string error = "Failed to cast data from datatype " +
TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), FmtToStr(TypeUtils::DataTypeToSerialString(args.src_data_type)) + " to " +
TypeUtils::DataTypeToSerialString(args.dst_data_type).c_str(), args.src_data_size); FmtToStr(TypeUtils::DataTypeToSerialString(args.dst_data_type)) + ", data size is " +
FmtToStr(std::to_string(args.src_data_size));
GE_ERRORLOG_AND_ERRORMSG(INTERNAL_ERROR, error.c_str());
return INTERNAL_ERROR; return INTERNAL_ERROR;
} }
result.data = dst; result.data = dst;

@ -22,6 +22,7 @@
#include "common/formats/utils/formats_definitions.h" #include "common/formats/utils/formats_definitions.h"
#include "common/formats/utils/formats_trans_utils.h" #include "common/formats/utils/formats_trans_utils.h"
#include "framework/common/debug/ge_log.h" #include "framework/common/debug/ge_log.h"
#include "framework/common/debug/log.h"
#include "graph/utils/type_utils.h" #include "graph/utils/type_utils.h"
namespace ge { namespace ge {
@ -35,14 +36,16 @@ Status CheckArgsForC1hwncoc0ToHwcn(const TransArgs &args) {
auto src_shape = args.src_shape; auto src_shape = args.src_shape;
auto dst_shape = args.dst_shape; auto dst_shape = args.dst_shape;
if (args.src_format != FORMAT_C1HWNCoC0 || args.dst_format != FORMAT_HWCN) { if (args.src_format != FORMAT_C1HWNCoC0 || args.dst_format != FORMAT_HWCN) {
GELOGE(UNSUPPORTED, "Does not support trans format from %s to %s", std::string error = "Dose not support trans format from " +
TypeUtils::FormatToSerialString(args.src_format).c_str(), FmtToStr(TypeUtils::FormatToSerialString(args.src_format)) + " to " +
TypeUtils::FormatToSerialString(args.dst_format).c_str()); FmtToStr(TypeUtils::FormatToSerialString(args.dst_format));
GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str());
return UNSUPPORTED; return UNSUPPORTED;
} }
if (!CheckDataTypeSupported(args.src_data_type)) { if (!CheckDataTypeSupported(args.src_data_type)) {
GELOGE(UNSUPPORTED, "Failed to trans shape from NC1HWNCoC0 to HWCN, invalid data type %s", std::string error = "Failed to trans shape from NC1HWNCoC0 to HWCN, invalid data type" +
TypeUtils::DataTypeToSerialString(args.src_data_type).c_str()); FmtToStr(TypeUtils::DataTypeToSerialString(args.src_data_type));
GE_ERRORLOG_AND_ERRORMSG(UNSUPPORTED, error.c_str());
return UNSUPPORTED; return UNSUPPORTED;
} }
if (!CheckShapeValid(src_shape, kC1hwncoc0DimsNum)) { if (!CheckShapeValid(src_shape, kC1hwncoc0DimsNum)) {
@ -58,8 +61,9 @@ Status CheckArgsForC1hwncoc0ToHwcn(const TransArgs &args) {
src_shape.at(kC1hwncoc0H) != dst_shape.at(kHwcnH) || src_shape.at(kC1hwncoc0W) != dst_shape.at(kHwcnW) || src_shape.at(kC1hwncoc0H) != dst_shape.at(kHwcnH) || src_shape.at(kC1hwncoc0W) != dst_shape.at(kHwcnW) ||
src_shape.at(kC1hwncoc0N) != dst_shape.at(kHwcnN) || src_shape.at(kC1hwncoc0Co) != cube_size || src_shape.at(kC1hwncoc0N) != dst_shape.at(kHwcnN) || src_shape.at(kC1hwncoc0Co) != cube_size ||
src_shape.at(kC1hwncoc0C0) != cube_size) { src_shape.at(kC1hwncoc0C0) != cube_size) {
GELOGE(PARAM_INVALID, "Failed to check relationship between src and dst shape, src shape %s, dst shape %s", std::string error = "Failed to check relationship between src and dst shape, src shape" +
ShapeToString(src_shape).c_str(), ShapeToString(dst_shape).c_str()); FmtToStr(ShapeToString(src_shape)) + ", dst shape" + FmtToStr(ShapeToString(dst_shape));
GE_ERRORLOG_AND_ERRORMSG(PARAM_INVALID, error.c_str());
return PARAM_INVALID; return PARAM_INVALID;
} }

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save