|
|
|
@ -109,6 +109,21 @@ bool IsInputFormatDtypeMatched(const KernelAttr &kernel_attr, const std::vector<
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ExpandKernelAttr(const CNodePtr &kernel_node, KernelAttr *kernel_attr) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_attr);
|
|
|
|
|
TypeId input_dtype = kernel_attr->GetInputAttr(0).first;
|
|
|
|
|
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
|
|
|
|
for (size_t i = 1; i < input_num; ++i) {
|
|
|
|
|
kernel_attr->AddInputAttr(input_dtype);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
TypeId output_dtype = kernel_attr->GetOutputAttr(0).first;
|
|
|
|
|
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
|
|
|
|
for (size_t i = 1; i < output_num; ++i) {
|
|
|
|
|
kernel_attr->AddOutputAttr(output_dtype);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
void SetKernelInfo(const CNodePtr &kernel_node) {
|
|
|
|
@ -125,12 +140,16 @@ void SetKernelInfo(const CNodePtr &kernel_node) {
|
|
|
|
|
kernel::CPUKernelFactory::GetInstance().GetSupportedKernelAttrList(AnfAlgo::GetCNodeName(kernel_node));
|
|
|
|
|
|
|
|
|
|
for (size_t index = 0; index < kernel_attrs.size(); ++index) {
|
|
|
|
|
if (IsInputFormatDtypeMatched(kernel_attrs[index], input_formats, input_types, input_not_cnode_indexes)) {
|
|
|
|
|
auto kernel_attr = kernel_attrs[index];
|
|
|
|
|
if (kernel_attr.GetAllSame()) {
|
|
|
|
|
ExpandKernelAttr(kernel_node, &kernel_attr);
|
|
|
|
|
}
|
|
|
|
|
if (IsInputFormatDtypeMatched(kernel_attr, input_formats, input_types, input_not_cnode_indexes)) {
|
|
|
|
|
MS_LOG(INFO) << "Input format and dtype is matched, index: " << index;
|
|
|
|
|
GetOutputFormatsAndDtypes(kernel_node, kernel_attrs[index], &output_formats, &output_types);
|
|
|
|
|
UpdatePrevNotCNodeFormatDtype(kernel_attrs[index], input_not_cnode_indexes, kernel_node);
|
|
|
|
|
GetOutputFormatsAndDtypes(kernel_node, kernel_attr, &output_formats, &output_types);
|
|
|
|
|
UpdatePrevNotCNodeFormatDtype(kernel_attr, input_not_cnode_indexes, kernel_node);
|
|
|
|
|
for (auto &input_index : input_not_cnode_indexes) {
|
|
|
|
|
input_types[input_index] = kernel_attrs[index].GetInputAttr(input_index).first;
|
|
|
|
|
input_types[input_index] = kernel_attr.GetInputAttr(input_index).first;
|
|
|
|
|
}
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|