diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc b/mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc index c809618b33..adf5280a8b 100644 --- a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc +++ b/mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc @@ -383,6 +383,11 @@ bool SetKernelBuilderInputInfo(const std::vector> &inp return false; } + std::vector reshape_type; + if (!StringToAxisVector(input->reshape_type(), &reshape_type)) { + return false; + } + if (param_type == "dynamic") { if (dyn_input_sizes.empty()) { MS_LOG(ERROR) << "Set input kernel builder info, dyn_input_sizes's size is 0 when param_type is dynamic"; @@ -394,6 +399,7 @@ bool SetKernelBuilderInputInfo(const std::vector> &inp auto type_id = tbe::DtypeToTypeId(dtypes[builder_idex]); inputs_device_type.push_back(type_id); inputs_format.push_back(formats[builder_idex]); + reshape_types.push_back(reshape_type); } dyn_input_idx++; } else if (param_type == "required") { @@ -401,6 +407,7 @@ bool SetKernelBuilderInputInfo(const std::vector> &inp auto type_id = tbe::DtypeToTypeId(dtypes[builder_idex]); inputs_device_type.push_back(type_id); inputs_format.push_back(formats[builder_idex]); + reshape_types.push_back(reshape_type); } else { if (kernel_info_index < real_input_num) { MS_LOG(INFO) << "Set input kernel builder info, input type is optional, input index is " << kernel_info_index; @@ -408,13 +415,9 @@ bool SetKernelBuilderInputInfo(const std::vector> &inp auto type_id = tbe::DtypeToTypeId(dtypes[builder_idex]); inputs_device_type.push_back(type_id); inputs_format.push_back(formats[builder_idex]); + reshape_types.push_back(reshape_type); } } - std::vector reshape_type; - if (!StringToAxisVector(input->reshape_type(), &reshape_type)) { - return false; - } - reshape_types.push_back(reshape_type); } builder->SetInputReshapeType(reshape_types); @@ -442,6 +445,11 @@ bool SetKernelBuilderOutputInfo(const std::vector> &ou MS_LOG(WARNING) << "real_output_num: " << real_output_num << ", output_idx: " << output_idx << "is out of limit!"; continue; } + std::vector reshape_type; + if (!StringToAxisVector(output->reshape_type(), &reshape_type)) { + return false; + } + size_t output_num = 0; if (output->param_type() == "dynamic") { if (outputs.size() > 1) { @@ -467,13 +475,9 @@ bool SetKernelBuilderOutputInfo(const std::vector> &ou auto type_id = tbe::DtypeToTypeId(dtypes[builder_idex]); outputs_device_type.push_back(type_id); outputs_format.push_back(formats[builder_idex]); + reshape_types.push_back(reshape_type); output_idx++; } - std::vector reshape_type; - if (!StringToAxisVector(output->reshape_type(), &reshape_type)) { - return false; - } - reshape_types.push_back(reshape_type); } builder->SetOutputReshapeType(reshape_types);