|
|
@ -35,6 +35,7 @@ enum MatchCountPriority : int {
|
|
|
|
MATCH_COUNT_PRIORITY_BEGIN = 0,
|
|
|
|
MATCH_COUNT_PRIORITY_BEGIN = 0,
|
|
|
|
MATCH_DTYPE_COUNT = MATCH_COUNT_PRIORITY_BEGIN,
|
|
|
|
MATCH_DTYPE_COUNT = MATCH_COUNT_PRIORITY_BEGIN,
|
|
|
|
MATCH_FORMAT_COUNT,
|
|
|
|
MATCH_FORMAT_COUNT,
|
|
|
|
|
|
|
|
MATCH_SPECIAL_FORMAT_COUNT,
|
|
|
|
MATCH_5D_FORMAT_COUNT,
|
|
|
|
MATCH_5D_FORMAT_COUNT,
|
|
|
|
MATCH_OUTPUT_DTYPE_COUNT,
|
|
|
|
MATCH_OUTPUT_DTYPE_COUNT,
|
|
|
|
MATCH_COUNT_PRIORITY_END
|
|
|
|
MATCH_COUNT_PRIORITY_END
|
|
|
@ -81,6 +82,12 @@ bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel::
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return true;
|
|
|
|
return true;
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
if (AnfAlgo::GetCNodeName(kernel_node) == "LayerNormBetaGammaBackprop" ||
|
|
|
|
|
|
|
|
AnfAlgo::GetCNodeName(kernel_node) == "LayerNormXBackprop") {
|
|
|
|
|
|
|
|
if (AnfAlgo::GetPrevNodeOutputFormat(kernel_node, 0) != kernel_build_info.GetInputFormat(0)) {
|
|
|
|
|
|
|
|
return true;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
if (AnfAlgo::GetCNodeName(kernel_node) == prim::kPrimCast->name()) {
|
|
|
|
if (AnfAlgo::GetCNodeName(kernel_node) == prim::kPrimCast->name()) {
|
|
|
|
return AnfAlgo::GetOutputInferDataType(kernel_node, 0) == kernel_build_info.GetOutputDeviceType(0) &&
|
|
|
|
return AnfAlgo::GetOutputInferDataType(kernel_node, 0) == kernel_build_info.GetOutputDeviceType(0) &&
|
|
|
|
AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0) == kernel_build_info.GetInputDeviceType(0);
|
|
|
|
AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, 0) == kernel_build_info.GetInputDeviceType(0);
|
|
|
@ -154,7 +161,7 @@ bool PriorityChooseItem(const std::vector<int> &cur_item, std::vector<int> *best
|
|
|
|
return false;
|
|
|
|
return false;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return false;
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, const std::shared_ptr<CNode> &kernel_node,
|
|
|
|
void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, const std::shared_ptr<CNode> &kernel_node,
|
|
|
@ -174,12 +181,11 @@ void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, cons
|
|
|
|
continue;
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if (input_anf_node->isa<ValueNode>()) {
|
|
|
|
|
|
|
|
if (AnfAlgo::GetOutputDeviceDataType(input_anf_node, 0) == kTypeUnknown) {
|
|
|
|
|
|
|
|
continue;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if (kernel_build_info.GetInputFormat(input_index) == AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index)) {
|
|
|
|
if (kernel_build_info.GetInputFormat(input_index) == AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index)) {
|
|
|
|
|
|
|
|
if (AnfAlgo::IsFeatureMapInput(kernel_node, input_index) &&
|
|
|
|
|
|
|
|
kSpecialFormatSet.find(kernel_build_info.GetInputFormat(input_index)) != kSpecialFormatSet.end()) {
|
|
|
|
|
|
|
|
(*cur_kernelinfo_match_counts)[MATCH_SPECIAL_FORMAT_COUNT]++;
|
|
|
|
|
|
|
|
}
|
|
|
|
(*cur_kernelinfo_match_counts)[MATCH_FORMAT_COUNT]++;
|
|
|
|
(*cur_kernelinfo_match_counts)[MATCH_FORMAT_COUNT]++;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if (kernel_build_info.GetInputDeviceType(input_index) ==
|
|
|
|
if (kernel_build_info.GetInputDeviceType(input_index) ==
|
|
|
@ -203,7 +209,7 @@ void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, cons
|
|
|
|
(*cur_kernelinfo_match_counts)[MATCH_OUTPUT_DTYPE_COUNT]++;
|
|
|
|
(*cur_kernelinfo_match_counts)[MATCH_OUTPUT_DTYPE_COUNT]++;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
|
|
void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node) {
|
|
|
|
void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, const CNodePtr &kernel_node) {
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_node);
|
|
|
|