|
|
|
@ -176,9 +176,18 @@ bool IsNeedProcessFormatInfo(const CNodePtr &kernel_node, const std::vector<Type
|
|
|
|
|
if (inputs_type.size() == 0) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
|
|
|
|
if (input_shape.size() != 4) {
|
|
|
|
|
return false;
|
|
|
|
|
auto inputs_format_position = iter->second.first;
|
|
|
|
|
// If input position is empty, then insert all the input positions, because the input numbers of this op are variable.
|
|
|
|
|
if (inputs_format_position.size() == 0) {
|
|
|
|
|
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); input_index++) {
|
|
|
|
|
inputs_format_position.push_back(input_index);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
for (const auto &input_format_position : inputs_format_position) {
|
|
|
|
|
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, input_format_position);
|
|
|
|
|
if (input_shape.size() != 4) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
@ -223,7 +232,7 @@ void UpdateKernelFormatInfo(const CNodePtr &kernel_node, const std::vector<TypeI
|
|
|
|
|
}
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
void SetKernelInfo(const CNodePtr &kernel_node, bool in_black_list) {
|
|
|
|
|
void SetKernelInfo(const CNodePtr &kernel_node, bool graph_format_transform) {
|
|
|
|
|
std::vector<std::string> inputs_format;
|
|
|
|
|
std::vector<TypeId> inputs_type;
|
|
|
|
|
for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
|
|
|
|
@ -237,7 +246,7 @@ void SetKernelInfo(const CNodePtr &kernel_node, bool in_black_list) {
|
|
|
|
|
outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index));
|
|
|
|
|
}
|
|
|
|
|
std::string origin_data_format = kOpFormat_DEFAULT;
|
|
|
|
|
if (!in_black_list && IsNeedProcessFormatInfo(kernel_node, inputs_type)) {
|
|
|
|
|
if (graph_format_transform && IsNeedProcessFormatInfo(kernel_node, inputs_type)) {
|
|
|
|
|
UpdateKernelFormatInfo(kernel_node, inputs_type, &inputs_format, &outputs_format, &origin_data_format);
|
|
|
|
|
}
|
|
|
|
|
std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> builder =
|
|
|
|
|