|
|
|
@ -52,6 +52,17 @@ void UpdatePrevNotCNodeFormatDtype(const KernelAttr &kernel_attr, const std::vec
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void GetOutputInferFormatsAndDtypes(const CNodePtr &kernel_node, std::vector<std::string> *output_formats,
|
|
|
|
|
std::vector<TypeId> *output_types) {
|
|
|
|
|
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
|
|
|
|
for (size_t output_index = 0; output_index < output_num; ++output_index) {
|
|
|
|
|
TypeId dtype = kTypeUnknown;
|
|
|
|
|
dtype = AnfAlgo::GetOutputInferDataType(kernel_node, output_index);
|
|
|
|
|
output_formats->emplace_back(kOpFormat_DEFAULT);
|
|
|
|
|
output_types->emplace_back(dtype);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void GetInputFormatsAndDtypes(const CNodePtr &kernel_node, std::vector<std::string> *input_formats,
|
|
|
|
|
std::vector<TypeId> *input_types, std::vector<size_t> *input_no_cnode_indexes) {
|
|
|
|
|
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
|
|
|
@ -78,10 +89,53 @@ void GetOutputFormatsAndDtypes(const CNodePtr &kernel_node, const KernelAttr &ke
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Decides whether an actual input dtype is acceptable for the dtype a kernel attribute requires.
//
// InputAttr:  dtype required by the kernel attribute.
// input_type: dtype of the actual input.
// strict:     when true, only an exact match is accepted.
//
// When `strict` is false, two additional pairings are accepted:
//   - required kNumberTypeInt32   with actual kNumberTypeInt16   or kNumberTypeInt64
//   - required kNumberTypeFloat32 with actual kNumberTypeFloat16 or kNumberTypeFloat64
bool InputDtypeMatch(TypeId InputAttr, TypeId input_type, bool strict) {
  if (InputAttr == input_type) {
    return true;
  }
  // Every remaining rule is a relaxed one, so strict mode stops here.
  if (strict) {
    return false;
  }
  const bool int_family_ok =
    InputAttr == kNumberTypeInt32 && (input_type == kNumberTypeInt16 || input_type == kNumberTypeInt64);
  const bool float_family_ok =
    InputAttr == kNumberTypeFloat32 && (input_type == kNumberTypeFloat16 || input_type == kNumberTypeFloat64);
  return int_family_ok || float_family_ok;
}
|
|
|
|
|
|
|
|
|
|
std::pair<int, int> GetOutputDtypeFormatMatchedNum(const KernelAttr &kernel_attr,
|
|
|
|
|
const std::vector<std::string> &output_formats,
|
|
|
|
|
const std::vector<TypeId> &output_types) {
|
|
|
|
|
if (kernel_attr.GetOutputSize() != output_types.size()) {
|
|
|
|
|
MS_LOG(DEBUG) << "required output num:" << kernel_attr.GetInputSize()
|
|
|
|
|
<< ", actual output num:" << output_types.size();
|
|
|
|
|
return std::make_pair(0, 0);
|
|
|
|
|
}
|
|
|
|
|
int data_type_matched_num = 0;
|
|
|
|
|
int format_matched_num = 0;
|
|
|
|
|
auto output_num = output_types.size();
|
|
|
|
|
for (size_t i = 0; i < output_num; ++i) {
|
|
|
|
|
if (kernel_attr.GetOutputAttr(i).first != output_types[i]) {
|
|
|
|
|
MS_LOG(DEBUG) << "required dtype:" << kernel_attr.GetOutputAttr(i).first
|
|
|
|
|
<< ", actual output dtype:" << output_types[i];
|
|
|
|
|
} else {
|
|
|
|
|
data_type_matched_num++;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (kernel_attr.GetOutputAttr(i).second != output_formats[i]) {
|
|
|
|
|
MS_LOG(DEBUG) << "required format:" << kernel_attr.GetOutputAttr(i).second
|
|
|
|
|
<< ", actual output format:" << output_formats[i];
|
|
|
|
|
} else {
|
|
|
|
|
format_matched_num++;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return std::make_pair(data_type_matched_num, format_matched_num);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::pair<int, int> GetInputDtypeFormatMatchedNum(const KernelAttr &kernel_attr,
|
|
|
|
|
const std::vector<std::string> &input_formats,
|
|
|
|
|
const std::vector<TypeId> &input_types,
|
|
|
|
|
const std::vector<size_t> &input_not_cnode_indexes) {
|
|
|
|
|
const std::vector<size_t> &input_not_cnode_indexes, bool strict) {
|
|
|
|
|
if (kernel_attr.GetInputSize() != input_types.size()) {
|
|
|
|
|
MS_LOG(DEBUG) << "required input num:" << kernel_attr.GetInputSize() << ", actual input num:" << input_types.size();
|
|
|
|
|
return std::make_pair(0, 0);
|
|
|
|
@ -98,12 +152,23 @@ std::pair<int, int> GetInputDtypeFormatMatchedNum(const KernelAttr &kernel_attr,
|
|
|
|
|
format_matched_num++;
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
if (is_not_cnode_idx) {
|
|
|
|
|
if (!InputDtypeMatch(kernel_attr.GetInputAttr(i).first, input_types[i], strict)) {
|
|
|
|
|
MS_LOG(DEBUG) << "required dtype:" << kernel_attr.GetInputAttr(i).first
|
|
|
|
|
<< ", actual input dtype:" << input_types[i];
|
|
|
|
|
} else {
|
|
|
|
|
data_type_matched_num++;
|
|
|
|
|
}
|
|
|
|
|
format_matched_num++;
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
if (kernel_attr.GetInputAttr(i).first != input_types[i]) {
|
|
|
|
|
MS_LOG(DEBUG) << "required dtype:" << kernel_attr.GetInputAttr(i).first
|
|
|
|
|
<< ", actual input dtype:" << input_types[i];
|
|
|
|
|
} else {
|
|
|
|
|
data_type_matched_num++;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (kernel_attr.GetInputAttr(i).second != input_formats[i]) {
|
|
|
|
|
MS_LOG(DEBUG) << "required format:" << kernel_attr.GetInputAttr(i).second
|
|
|
|
|
<< ", actual input format:" << input_formats[i];
|
|
|
|
@ -141,23 +206,13 @@ void SetKernelBuildInfo(const std::vector<std::string> &input_formats, const std
|
|
|
|
|
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), kernel_node);
|
|
|
|
|
}
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
void SetKernelInfo(const CNodePtr &kernel_node) {
|
|
|
|
|
std::vector<std::string> input_formats;
|
|
|
|
|
std::vector<TypeId> input_types;
|
|
|
|
|
std::vector<size_t> input_not_cnode_indexes;
|
|
|
|
|
std::vector<std::string> output_formats;
|
|
|
|
|
std::vector<TypeId> output_types;
|
|
|
|
|
MS_LOG(INFO) << "SetKernelInfo, CNode Name: " << AnfAlgo::GetCNodeName(kernel_node);
|
|
|
|
|
GetInputFormatsAndDtypes(kernel_node, &input_formats, &input_types, &input_not_cnode_indexes);
|
|
|
|
|
auto kernel_attrs =
|
|
|
|
|
kernel::CPUKernelFactory::GetInstance().GetSupportedKernelAttrList(AnfAlgo::GetCNodeName(kernel_node));
|
|
|
|
|
if (kernel_attrs.empty()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Operator[" << AnfAlgo::GetCNodeName(kernel_node) << "] is not support.";
|
|
|
|
|
}
|
|
|
|
|
bool SelectKernel(const CNodePtr &kernel_node, KernelAttr *selected_kernel_attr,
|
|
|
|
|
const std::vector<KernelAttr> &kernel_attrs, const std::vector<std::string> &input_formats,
|
|
|
|
|
const std::vector<TypeId> &input_types, const std::vector<size_t> &input_not_cnode_indexes,
|
|
|
|
|
const std::vector<std::string> &infer_output_formats, const std::vector<TypeId> &infer_output_types,
|
|
|
|
|
bool strict) {
|
|
|
|
|
int max_type_matched_num = -1;
|
|
|
|
|
int max_format_matched_num = -1;
|
|
|
|
|
KernelAttr selected_kernel_attr;
|
|
|
|
|
for (auto kernel_attr : kernel_attrs) {
|
|
|
|
|
if (kernel_attr.GetAllSame()) {
|
|
|
|
|
ExpandKernelAttr(kernel_node, &kernel_attr);
|
|
|
|
@ -168,29 +223,61 @@ void SetKernelInfo(const CNodePtr &kernel_node) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
std::pair<int, int> input_type_format_matched_num =
|
|
|
|
|
GetInputDtypeFormatMatchedNum(kernel_attr, input_formats, input_types, input_not_cnode_indexes);
|
|
|
|
|
GetInputDtypeFormatMatchedNum(kernel_attr, input_formats, input_types, input_not_cnode_indexes, strict);
|
|
|
|
|
std::pair<int, int> output_type_format_matched_num =
|
|
|
|
|
GetOutputDtypeFormatMatchedNum(kernel_attr, infer_output_formats, infer_output_types);
|
|
|
|
|
// Data type first
|
|
|
|
|
if (input_type_format_matched_num.first > max_type_matched_num) {
|
|
|
|
|
max_type_matched_num = input_type_format_matched_num.first;
|
|
|
|
|
max_format_matched_num = input_type_format_matched_num.second;
|
|
|
|
|
selected_kernel_attr = kernel_attr;
|
|
|
|
|
*selected_kernel_attr = kernel_attr;
|
|
|
|
|
} else if (input_type_format_matched_num.first == max_type_matched_num &&
|
|
|
|
|
input_type_format_matched_num.second > max_format_matched_num) {
|
|
|
|
|
max_format_matched_num = input_type_format_matched_num.second;
|
|
|
|
|
selected_kernel_attr = kernel_attr;
|
|
|
|
|
*selected_kernel_attr = kernel_attr;
|
|
|
|
|
} else if (input_type_format_matched_num.first == max_type_matched_num &&
|
|
|
|
|
input_type_format_matched_num.second == max_format_matched_num) {
|
|
|
|
|
if (output_type_format_matched_num.first == SizeToInt(infer_output_types.size()) &&
|
|
|
|
|
output_type_format_matched_num.second == SizeToInt(infer_output_types.size())) {
|
|
|
|
|
*selected_kernel_attr = kernel_attr;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// All formats and data types matched
|
|
|
|
|
if (max_type_matched_num == SizeToInt(input_types.size()) &&
|
|
|
|
|
max_format_matched_num == SizeToInt(input_types.size())) {
|
|
|
|
|
break;
|
|
|
|
|
max_format_matched_num == SizeToInt(input_types.size()) &&
|
|
|
|
|
output_type_format_matched_num.first == SizeToInt(infer_output_types.size()) &&
|
|
|
|
|
output_type_format_matched_num.second == SizeToInt(infer_output_types.size())) {
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
void SetKernelInfo(const CNodePtr &kernel_node) {
|
|
|
|
|
std::vector<std::string> input_formats;
|
|
|
|
|
std::vector<TypeId> input_types;
|
|
|
|
|
std::vector<size_t> input_not_cnode_indexes;
|
|
|
|
|
std::vector<std::string> output_formats;
|
|
|
|
|
std::vector<TypeId> output_types;
|
|
|
|
|
std::vector<std::string> infer_output_formats;
|
|
|
|
|
std::vector<TypeId> infer_output_types;
|
|
|
|
|
MS_LOG(INFO) << "SetKernelInfo, CNode Name: " << AnfAlgo::GetCNodeName(kernel_node);
|
|
|
|
|
GetInputFormatsAndDtypes(kernel_node, &input_formats, &input_types, &input_not_cnode_indexes);
|
|
|
|
|
GetOutputInferFormatsAndDtypes(kernel_node, &infer_output_formats, &infer_output_types);
|
|
|
|
|
auto kernel_attrs =
|
|
|
|
|
kernel::CPUKernelFactory::GetInstance().GetSupportedKernelAttrList(AnfAlgo::GetCNodeName(kernel_node));
|
|
|
|
|
if (kernel_attrs.empty()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Operator[" << AnfAlgo::GetCNodeName(kernel_node) << "] is not support.";
|
|
|
|
|
}
|
|
|
|
|
KernelAttr selected_kernel_attr;
|
|
|
|
|
bool matched = true;
|
|
|
|
|
if (!SelectKernel(kernel_node, &selected_kernel_attr, kernel_attrs, input_formats, input_types,
|
|
|
|
|
input_not_cnode_indexes, infer_output_formats, infer_output_types, true)) {
|
|
|
|
|
matched = SelectKernel(kernel_node, &selected_kernel_attr, kernel_attrs, input_formats, input_types,
|
|
|
|
|
input_not_cnode_indexes, infer_output_formats, infer_output_types, false);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (selected_kernel_attr.GetInputSize() > 0 && ((max_type_matched_num == SizeToInt(input_types.size()) &&
|
|
|
|
|
max_format_matched_num == SizeToInt(input_types.size())) ||
|
|
|
|
|
input_types.size() == input_not_cnode_indexes.size())) {
|
|
|
|
|
MS_LOG(INFO) << "Input format and dtype is matched, max_type_matched_num: " << max_type_matched_num
|
|
|
|
|
<< ", max_format_matched_num: " << max_format_matched_num;
|
|
|
|
|
if (selected_kernel_attr.GetInputSize() > 0 && (matched || input_types.size() == input_not_cnode_indexes.size())) {
|
|
|
|
|
MS_LOG(INFO) << "Input format and dtype is matched";
|
|
|
|
|
GetOutputFormatsAndDtypes(kernel_node, selected_kernel_attr, &output_formats, &output_types);
|
|
|
|
|
UpdatePrevNotCNodeFormatDtype(selected_kernel_attr, input_not_cnode_indexes, kernel_node);
|
|
|
|
|
for (auto &input_index : input_not_cnode_indexes) {
|
|
|
|
|