|
|
|
@ -71,9 +71,6 @@ void GetInputFormatsAndDtypes(const CNodePtr &kernel_node, std::vector<std::stri
|
|
|
|
|
void GetOutputFormatsAndDtypes(const CNodePtr &kernel_node, const KernelAttr &kernel_attr,
|
|
|
|
|
std::vector<std::string> *output_formats, std::vector<TypeId> *output_types) {
|
|
|
|
|
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
|
|
|
|
if (kernel_attr.GetOutputSize() != output_num) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Output num is not equal!";
|
|
|
|
|
}
|
|
|
|
|
for (size_t output_index = 0; output_index < output_num; ++output_index) {
|
|
|
|
|
output_formats->emplace_back(kernel_attr.GetOutputAttr(output_index).second);
|
|
|
|
|
auto dtype = kernel_attr.GetOutputAttr(output_index).first;
|
|
|
|
@ -145,6 +142,11 @@ void SetKernelInfo(const CNodePtr &kernel_node) {
|
|
|
|
|
ExpandKernelAttr(kernel_node, &kernel_attr);
|
|
|
|
|
}
|
|
|
|
|
if (IsInputFormatDtypeMatched(kernel_attr, input_formats, input_types, input_not_cnode_indexes)) {
|
|
|
|
|
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
|
|
|
|
if (kernel_attr.GetOutputSize() != output_num) {
|
|
|
|
|
MS_LOG(DEBUG) << "Output num is not equal!";
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(INFO) << "Input format and dtype is matched, index: " << index;
|
|
|
|
|
GetOutputFormatsAndDtypes(kernel_node, kernel_attr, &output_formats, &output_types);
|
|
|
|
|
UpdatePrevNotCNodeFormatDtype(kernel_attr, input_not_cnode_indexes, kernel_node);
|
|
|
|
|