|
|
|
@ -59,6 +59,7 @@ void GetInputFormatsAndDtypes(const CNodePtr &kernel_node, std::vector<std::stri
|
|
|
|
|
TypeId dtype = kTypeUnknown;
|
|
|
|
|
if (IsInputNotCNode(kernel_node, input_index)) {
|
|
|
|
|
input_no_cnode_indexes->emplace_back(input_index);
|
|
|
|
|
dtype = AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index);
|
|
|
|
|
} else {
|
|
|
|
|
dtype = AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index);
|
|
|
|
|
}
|
|
|
|
@ -84,22 +85,25 @@ bool IsInputFormatDtypeMatched(const KernelAttr &kernel_attr, const std::vector<
|
|
|
|
|
const std::vector<TypeId> &input_types,
|
|
|
|
|
const std::vector<size_t> &input_not_cnode_indexes) {
|
|
|
|
|
if (kernel_attr.GetInputSize() != input_types.size()) {
|
|
|
|
|
MS_LOG(ERROR) << "Output num is not equal!";
|
|
|
|
|
MS_LOG(ERROR) << "required input num:" << kernel_attr.GetInputSize() << ", actual input num:" << input_types.size();
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
auto input_num = input_types.size();
|
|
|
|
|
for (size_t i = 0; i < input_num; ++i) {
|
|
|
|
|
bool is_not_cnode_idx = std::any_of(input_not_cnode_indexes.begin(), input_not_cnode_indexes.end(),
|
|
|
|
|
[i](size_t index) { return index == i; });
|
|
|
|
|
if (is_not_cnode_idx) {
|
|
|
|
|
bool have_cnode_input = (input_types.size() != input_not_cnode_indexes.size());
|
|
|
|
|
if (have_cnode_input && is_not_cnode_idx) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
if (kernel_attr.GetInputAttr(i).first != input_types[i]) {
|
|
|
|
|
MS_LOG(ERROR) << "reg dtype=" << kernel_attr.GetInputAttr(i).first << ", input dtype=" << input_types[i];
|
|
|
|
|
MS_LOG(DEBUG) << "required dtype:" << kernel_attr.GetInputAttr(i).first
|
|
|
|
|
<< ", actual input dtype:" << input_types[i];
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
if (kernel_attr.GetInputAttr(i).second != input_formats[i]) {
|
|
|
|
|
MS_LOG(ERROR) << "reg format=" << kernel_attr.GetInputAttr(i).second << ", input format=" << input_formats[i];
|
|
|
|
|
MS_LOG(DEBUG) << "required format:" << kernel_attr.GetInputAttr(i).second
|
|
|
|
|
<< ", actual input format:" << input_formats[i];
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -114,17 +118,19 @@ void SetKernelInfo(const CNodePtr &kernel_node) {
|
|
|
|
|
std::vector<std::string> output_formats;
|
|
|
|
|
std::vector<TypeId> output_types;
|
|
|
|
|
|
|
|
|
|
MS_LOG(INFO) << "SetKernelInfo, CNode Name: " << AnfAlgo::GetCNodeName(kernel_node);
|
|
|
|
|
GetInputFormatsAndDtypes(kernel_node, &input_formats, &input_types, &input_not_cnode_indexes);
|
|
|
|
|
|
|
|
|
|
auto kernel_attrs =
|
|
|
|
|
kernel::CPUKernelFactory::GetInstance().GetSupportedKernelAttrList(AnfAlgo::GetCNodeName(kernel_node));
|
|
|
|
|
|
|
|
|
|
for (auto &kernel_attr : kernel_attrs) {
|
|
|
|
|
if (IsInputFormatDtypeMatched(kernel_attr, input_formats, input_types, input_not_cnode_indexes)) {
|
|
|
|
|
GetOutputFormatsAndDtypes(kernel_node, kernel_attr, &output_formats, &output_types);
|
|
|
|
|
UpdatePrevNotCNodeFormatDtype(kernel_attr, input_not_cnode_indexes, kernel_node);
|
|
|
|
|
for (size_t index = 0; index < kernel_attrs.size(); ++index) {
|
|
|
|
|
if (IsInputFormatDtypeMatched(kernel_attrs[index], input_formats, input_types, input_not_cnode_indexes)) {
|
|
|
|
|
MS_LOG(INFO) << "Input format and dtype is matched, index: " << index;
|
|
|
|
|
GetOutputFormatsAndDtypes(kernel_node, kernel_attrs[index], &output_formats, &output_types);
|
|
|
|
|
UpdatePrevNotCNodeFormatDtype(kernel_attrs[index], input_not_cnode_indexes, kernel_node);
|
|
|
|
|
for (auto &input_index : input_not_cnode_indexes) {
|
|
|
|
|
input_types[input_index] = kernel_attr.GetInputAttr(input_index).first;
|
|
|
|
|
input_types[input_index] = kernel_attrs[index].GetInputAttr(input_index).first;
|
|
|
|
|
}
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|