fix compile error

pull/1115/head
wangcong 5 years ago
parent e477e7617b
commit 453514dd51

@ -721,7 +721,7 @@ size_t TbeKernelBuild::GetOptionalInput(const mindspore::CNodePtr &cnode, bool i
return (op_info->inputs_ptr().size() + 1 - cnode->inputs().size()); return (op_info->inputs_ptr().size() + 1 - cnode->inputs().size());
} }
std::string TbeKernelBuild::GetRealOpType(std::string &origin_type) { std::string TbeKernelBuild::GetRealOpType(const std::string &origin_type) {
static std::map<std::string, std::string> buffer_fussion_op_map = {{"DepthwiseConv2dNative", "DepthwiseConv2D"}, static std::map<std::string, std::string> buffer_fussion_op_map = {{"DepthwiseConv2dNative", "DepthwiseConv2D"},
{"TensorAdd", "Add"}}; {"TensorAdd", "Add"}};
string result = origin_type; string result = origin_type;
@ -834,9 +834,9 @@ bool TbeKernelBuild::GenFusionComputeJson(const mindspore::AnfNodePtr &compute_n
} }
(*compute_op_str)["output_desc"] = output_desc_list; (*compute_op_str)["output_desc"] = output_desc_list;
// gen others // gen others
auto type = AnfAlgo::GetCNodeName(cnode); auto origin_type = AnfAlgo::GetCNodeName(cnode);
// replace special op type for buffer fusion op // replace special op type for buffer fusion op
type = GetRealOpType(type); auto type = GetRealOpType(origin_type);
(*compute_op_str)["type"] = type; (*compute_op_str)["type"] = type;
tbe::TbeAdapter::NormalizeFuncName(&type); tbe::TbeAdapter::NormalizeFuncName(&type);
(*compute_op_str)["func_name"] = type; (*compute_op_str)["func_name"] = type;

@ -76,7 +76,7 @@ class TbeKernelBuild {
std::map<const AnfNodePtr, FusionDataType> *spec_data_input); std::map<const AnfNodePtr, FusionDataType> *spec_data_input);
static bool IsDynamicInput(const CNodePtr &cnode); static bool IsDynamicInput(const CNodePtr &cnode);
static size_t GetOptionalInput(const CNodePtr &cnode, bool is_dynamic_input); static size_t GetOptionalInput(const CNodePtr &cnode, bool is_dynamic_input);
std::string GetRealOpType(std::string &origin_type); static std::string GetRealOpType(const std::string &origin_type);
}; };
class TbeKernelJsonCreator { class TbeKernelJsonCreator {

Loading…
Cancel
Save