|
|
|
@ -1084,12 +1084,14 @@ std::string TbeKernelBuild::GetNodeFusionType(const mindspore::CNodePtr &cnode)
|
|
|
|
|
{kDepthwiseConv2dNativeOpName, "DepthwiseConvolution"},
|
|
|
|
|
{kAddNOpName, "ElemWise"},
|
|
|
|
|
{kReluGradV2OpName, "ElemWise"},
|
|
|
|
|
{kRealDivOpName, "ElemWise"}};
|
|
|
|
|
{kRealDivOpName, "ElemWise"},
|
|
|
|
|
{kBiasAddOpName, "BiasAdd"}};
|
|
|
|
|
auto find = fusion_type_map.find(node_type);
|
|
|
|
|
if (find == fusion_type_map.end()) {
|
|
|
|
|
MS_LOG(INFO) << "Fusion warning: get node fusion type failed, origin node type: " << node_type
|
|
|
|
|
<< " return null string.";
|
|
|
|
|
return "";
|
|
|
|
|
MS_LOG(INFO) << "Fusion warning: get node fusion type failed from lists, origin node type: " << node_type;
|
|
|
|
|
auto op_info = mindspore::kernel::tbe::TbeDynamicShapeUtil::FindOp(node_type, cnode);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(op_info);
|
|
|
|
|
return op_info->fusion_type();
|
|
|
|
|
} else {
|
|
|
|
|
return find->second;
|
|
|
|
|
}
|
|
|
|
|