From 518818fbef099653aaf212a939ff62b5b842fec6 Mon Sep 17 00:00:00 2001 From: liubuyu Date: Mon, 8 Mar 2021 14:30:09 +0800 Subject: [PATCH] reshape type for 3d nodes --- .../kernel_compiler/kernel_build_info.cc | 14 +- .../kernel_compiler/kernel_build_info.h | 20 +-- .../tbe_kernel_select/tbe_kernel_select.cc | 31 +--- .../tbe/tbe_kernel_select/tbe_kernel_select.h | 3 +- .../backend/optimizer/ascend/ascend_helper.cc | 20 +-- .../backend/optimizer/ascend/ascend_helper.h | 4 +- .../backend/session/anf_runtime_algorithm.cc | 10 +- .../backend/session/anf_runtime_algorithm.h | 6 +- mindspore/ccsrc/common/trans.cc | 132 +++++++++++++++++- mindspore/ccsrc/common/trans.h | 15 +- .../device/ascend/ascend_device_address.cc | 7 +- .../ccsrc/runtime/device/kernel_runtime.cc | 2 +- mindspore/ccsrc/utils/utils.h | 1 + mindspore/core/ir/tensor.h | 6 +- .../format_type/insert_trans_op_test.cc | 12 +- .../remove_internal_output_test.cc | 12 +- .../ascend/ir_fission/transdata_split_test.cc | 24 ++-- .../transpose_transdata_fusion_test.cc | 12 +- .../pass/eliminate_redundant_op_test.cc | 16 +-- 19 files changed, 227 insertions(+), 120 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.cc b/mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.cc index 2932e90064..526e7a8412 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.cc @@ -66,7 +66,7 @@ size_t KernelBuildInfo::GetInputNum() const { return inputs_format_.size(); } size_t KernelBuildInfo::GetOutputNum() const { return outputs_format_.size(); } -std::vector KernelBuildInfo::GetInputReshapeType(size_t input_index) const { +std::string KernelBuildInfo::GetInputReshapeType(size_t input_index) const { if (input_reshape_type_.empty()) { return {}; } @@ -77,7 +77,7 @@ std::vector KernelBuildInfo::GetInputReshapeType(size_t input_index) const return input_reshape_type_[input_index]; } -std::vector KernelBuildInfo::GetOutputReshapeType(size_t output_index) const { +std::string KernelBuildInfo::GetOutputReshapeType(size_t output_index) const { if (output_reshape_type_.empty()) { return {}; } @@ -175,14 +175,13 @@ void KernelBuildInfo::KernelBuildInfoBuilder::SetProcessor(Processor processor) std::shared_ptr KernelBuildInfo::KernelBuildInfoBuilder::Build() { return kernel_build_info_; } -void KernelBuildInfo::KernelBuildInfoBuilder::SetInputsReshapeType( - const std::vector> &input_reshape_type) { +void KernelBuildInfo::KernelBuildInfoBuilder::SetInputsReshapeType(const std::vector &input_reshape_type) { MS_EXCEPTION_IF_NULL(kernel_build_info_); kernel_build_info_->input_reshape_type_ = input_reshape_type; } void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputsReshapeType( - const std::vector> &output_reshape_type) { + const std::vector &output_reshape_type) { MS_EXCEPTION_IF_NULL(kernel_build_info_); kernel_build_info_->output_reshape_type_ = output_reshape_type; } @@ -206,8 +205,7 @@ void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputFormat(const std::string } kernel_build_info_->outputs_format_[index] = format; } -void KernelBuildInfo::KernelBuildInfoBuilder::SetInputReshapeType(const std::vector &input_reshape_type, - size_t index) { +void KernelBuildInfo::KernelBuildInfoBuilder::SetInputReshapeType(const std::string &input_reshape_type, size_t index) { if (index >= kernel_build_info_->input_reshape_type_.size()) { MS_LOG(EXCEPTION) << "index outof range!"; } @@ -215,7 +213,7 @@ void KernelBuildInfo::KernelBuildInfoBuilder::SetInputReshapeType(const std::vec std::back_inserter(kernel_build_info_->input_reshape_type_[index])); } -void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputReshapeType(const std::vector &output_reshape_type, +void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputReshapeType(const std::string &output_reshape_type, size_t index) { if (index >= kernel_build_info_->output_reshape_type_.size()) { MS_LOG(EXCEPTION) << "index outof range!"; diff --git a/mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.h b/mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.h index aa8496196b..30d69aae23 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.h +++ b/mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.h @@ -57,13 +57,13 @@ class KernelBuildInfo { TypeId GetOutputDeviceType(size_t output_index) const; - std::vector GetInputReshapeType(size_t input_index) const; + std::string GetInputReshapeType(size_t input_index) const; bool IsInputDefaultPadding() const; bool IsOutputDefaultPadding() const; - std::vector GetOutputReshapeType(size_t input_index) const; + std::string GetOutputReshapeType(size_t input_index) const; const std::string &GetOriginDataFormat() const; @@ -75,9 +75,9 @@ class KernelBuildInfo { const std::vector &GetAllOutputDeviceTypes() const; - std::vector> GetAllOutputReshapeType() const; + std::vector GetAllOutputReshapeType() const; - std::vector> GetAllInputReshapeType() const; + std::vector GetAllInputReshapeType() const; OpPattern op_pattern() const { return op_pattern_; } @@ -106,8 +106,8 @@ class KernelBuildInfo { std::vector inputs_format_; OpPattern op_pattern_; std::vector outputs_format_; - std::vector> input_reshape_type_; - std::vector> output_reshape_type_; + std::vector input_reshape_type_; + std::vector output_reshape_type_; std::vector inputs_device_type_; std::vector outputs_device_type_; FusionType fusion_type_; @@ -151,9 +151,9 @@ class KernelBuildInfo::KernelBuildInfoBuilder { void SetOutputsDeviceType(const std::vector &outputs_device_type); - void SetInputsReshapeType(const std::vector> &input_reshape_type); + void SetInputsReshapeType(const std::vector &input_reshape_type); - void SetOutputsReshapeType(const std::vector> &output_reshape_type); + void SetOutputsReshapeType(const std::vector &output_reshape_type); void SetFusionType(FusionType fusion_type); @@ -165,9 +165,9 @@ class KernelBuildInfo::KernelBuildInfoBuilder { void SetOutputFormat(const std::string &format, size_t index); - void SetInputReshapeType(const std::vector &input_reshape_type, size_t index); + void SetInputReshapeType(const std::string &input_reshape_type, size_t index); - void SetOutputReshapeType(const std::vector &output_reshape_type, size_t index); + void SetOutputReshapeType(const std::string &output_reshape_type, size_t index); void SetInputDeviceType(const TypeId &input_device_type, size_t index); diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc index b957e3974d..cb55d657e8 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc @@ -99,7 +99,7 @@ void TbeKernelSelect::GetCommonPatternKernelInfo(const OpInfo &op_info) { SetTbeBuildCommonInfo(op_info, &builder); std::vector inputs_format; std::vector inputs_device_type; - std::vector> inputs_reshape_type; + std::vector inputs_reshape_type; // input if (!GenBuilderItem(true, kernel_build_info_index, real_input_tensor_num, inputs_info, dyn_input_sizes, &inputs_format, &inputs_device_type, &inputs_reshape_type)) { @@ -111,7 +111,7 @@ void TbeKernelSelect::GetCommonPatternKernelInfo(const OpInfo &op_info) { // output std::vector outputs_format; std::vector outputs_device_type; - std::vector> outputs_reshape_type; + std::vector outputs_reshape_type; if (!GenBuilderItem(false, kernel_build_info_index, real_output_tensor_num, outputs_info, dyn_input_sizes, &outputs_format, &outputs_device_type, &outputs_reshape_type)) { break; @@ -290,7 +290,7 @@ std::vector TbeKernelSelect::GetNodeDynamicInputs() { bool TbeKernelSelect::GenBuilderItem(bool is_input, size_t kernel_build_info_index, size_t real_io_tensor_num, const std::vector> &ios_info, const std::vector &dyn_input_sizes, std::vector *formats, - std::vector *device_types, std::vector> *reshape_types) { + std::vector *device_types, std::vector *reshape_types) { MS_EXCEPTION_IF_NULL(formats); MS_EXCEPTION_IF_NULL(device_types); MS_EXCEPTION_IF_NULL(reshape_types); @@ -306,8 +306,7 @@ bool TbeKernelSelect::GenBuilderItem(bool is_input, size_t kernel_build_info_ind kernel_build_info_format = io_info_item->formats()[kernel_build_info_index]; } const std::string &io_param_type = io_info_item->param_type(); - std::vector reshape_type; - StringToAxisVector(io_info_item->reshape_type(), &reshape_type); + auto reshape_type = io_info_item->reshape_type(); if (io_param_type == kParamTypeDynamic) { // dynamic io if (is_input) { @@ -355,28 +354,6 @@ bool TbeKernelSelect::GenBuilderItem(bool is_input, size_t kernel_build_info_ind return true; } -void TbeKernelSelect::StringToAxisVector(const std::string &reshape_type_str, std::vector *reshape_type_vec) { - MS_EXCEPTION_IF_NULL(reshape_type_vec); - for (const auto &c : reshape_type_str) { - switch (c) { - case 'N': - reshape_type_vec->push_back(N); - break; - case 'C': - reshape_type_vec->push_back(C); - break; - case 'H': - reshape_type_vec->push_back(H); - break; - case 'W': - reshape_type_vec->push_back(W); - break; - default: - MS_LOG(EXCEPTION) << "Unknown axis " << c << "in reshape type."; - } - } -} - void TbeKernelSelect::CreateNewOpIOInfo(const mindspore::kernel::OpIOInfo &op_io_info, const std::vector> &support_format_item, size_t index, mindspore::kernel::OpIOInfo *op_io_info_new) { diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.h b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.h index 2547260ad2..84427aa306 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.h +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.h @@ -52,8 +52,7 @@ class TbeKernelSelect { bool GenBuilderItem(bool is_input, size_t kernel_build_info_index, size_t real_io_tensor_num, const std::vector> &ios_info, const std::vector &dyn_input_sizes, std::vector *formats, - std::vector *device_types, std::vector> *reshape_types); - static void StringToAxisVector(const std::string &reshape_type_str, std::vector *reshape_type_vec); + std::vector *device_types, std::vector *reshape_types); static void CreateNewOpInfo(const OpInfo &op_info, const SupportFormat &support_format, OpInfo *op_info_new); static void CreateNewOpIOInfo(const OpIOInfo &op_io_info, const std::vector> &support_format_item, size_t index, diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc index 7a314812b2..7bc76b4569 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc @@ -187,8 +187,8 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt AnfNodePtr input_node = is_insert_input ? AnfAlgo::GetInputNode(node->cast(), insert_index) : node; std::string input_format = is_insert_input ? default_format : AnfAlgo::GetOutputFormat(node, insert_index); std::string dst_format = is_insert_input ? AnfAlgo::GetInputFormat(node, insert_index) : default_format; - std::vector padding_axis = is_insert_input ? AnfAlgo::GetInputReshapeType(node, insert_index) - : AnfAlgo::GetOutputReshapeType(node, insert_index); + std::string padding_axis = is_insert_input ? AnfAlgo::GetInputReshapeType(node, insert_index) + : AnfAlgo::GetOutputReshapeType(node, insert_index); auto input_node_out_shape = is_insert_input ? AnfAlgo::GetPrevNodeOutputInferShape(node, insert_index) : AnfAlgo::GetOutputInferShape(input_node, insert_index); bool need_padding = is_insert_input ? trans::IsNeedPadding(dst_format, input_node_out_shape.size()) @@ -200,8 +200,8 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt } else if (is_insert_input) { // if need padding & is input need insert a transdata // reshape[padding shape] -> transdata[padding shape] -> node - auto padding_shape = - trans::PaddingShapeTo4d(input_node_out_shape, AnfAlgo::GetInputReshapeType(node, insert_index)); + auto padding_shape = trans::PaddingShape(input_node_out_shape, AnfAlgo::GetInputFormat(node, insert_index), + AnfAlgo::GetInputReshapeType(node, insert_index)); auto reshape_node = CreateReshapeNode(func_graph, input_node, kernel_select, padding_shape); trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, need_padding, prim::KPrimTransData->name()); trans_node = trans_data; @@ -222,8 +222,7 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt } void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, - const AnfNodePtr &trans_data, const std::vector &reshape_type, - const TypeId &type_id) { + const AnfNodePtr &trans_data, const std::string &reshape_type, const TypeId &type_id) { MS_EXCEPTION_IF_NULL(trans_data); auto ori_build_info = AnfAlgo::GetSelectKernelBuildInfo(trans_data); MS_EXCEPTION_IF_NULL(ori_build_info); @@ -249,9 +248,10 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, if (need_padding) { // if need padding we should set the transdata node's shape to the padding shape auto padding_axis = AnfAlgo::GetOutputReshapeType(input, 0); - AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)}, - {trans::PaddingShapeTo4d(AnfAlgo::GetOutputInferShape(input, 0), padding_axis)}, - trans_node.get()); + AnfAlgo::SetOutputInferTypeAndShape( + {AnfAlgo::GetOutputInferDataType(input, 0)}, + {trans::PaddingShape(AnfAlgo::GetOutputInferShape(input, 0), AnfAlgo::GetOutputFormat(input, 0), padding_axis)}, + trans_node.get()); } else { AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)}, {AnfAlgo::GetOutputInferShape(input, 0)}, trans_node.get()); @@ -273,7 +273,7 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, CNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format, const TypeId &input_type, const TypeId &output_type, const std::vector &origin_shape, const TypeId &origin_type, - const std::vector &reshape_type) { + const std::string &reshape_type) { MS_EXCEPTION_IF_NULL(func_graph); std::string input_format = format; std::string output_format = format; diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.h b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.h index 872d024810..5c2d497071 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.h @@ -88,7 +88,7 @@ class OpFinder { using OpFinderPtr = std::shared_ptr; void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, - const AnfNodePtr &trans_data, const std::vector &reshape_type = {}, + const AnfNodePtr &trans_data, const std::string &reshape_type = {""}, const TypeId &type_id = kTypeUnknown); CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select, @@ -97,7 +97,7 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, CNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format, const TypeId &input_type, const TypeId &output_type, const std::vector &origin_shape, const TypeId &origin_type, - const std::vector &reshape_type = std::vector{}); + const std::string &reshape_type = std::string{}); AnfNodePtr InsertTransOpForInput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const KernelSelectPtr &kernel_select); diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc index 41e92b6de4..6c7997194b 100644 --- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc +++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc @@ -586,7 +586,7 @@ std::string AnfRuntimeAlgorithm::GetPrevNodeOutputFormat(const AnfNodePtr &anf_n return AnfRuntimeAlgorithm::GetOutputFormat(kernel_with_index.first, kernel_with_index.second); } -std::vector AnfRuntimeAlgorithm::GetPrevNodeOutputReshapeType(const AnfNodePtr &node, size_t input_idx) { +std::string AnfRuntimeAlgorithm::GetPrevNodeOutputReshapeType(const AnfNodePtr &node, size_t input_idx) { KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx); return GetOutputReshapeType(kernel_with_index.first, kernel_with_index.second); } @@ -642,7 +642,7 @@ std::vector AnfRuntimeAlgorithm::GetOutputDeviceShape(const AnfNodePtr & } // if format is default_format or NC1KHKWHWC0,device shape = original shape if (trans::IsNeedPadding(format, infer_shape.size())) { - infer_shape = trans::PaddingShapeTo4d(infer_shape, GetOutputReshapeType(node, output_idx)); + infer_shape = trans::PaddingShape(infer_shape, format, GetOutputReshapeType(node, output_idx)); } return trans::TransShapeToDevice(infer_shape, format); } @@ -655,12 +655,12 @@ std::vector AnfRuntimeAlgorithm::GetInputDeviceShape(const AnfNodePtr &n } // if format is default_format or NC1KHKWHWC0,device shape = original shape if (trans::IsNeedPadding(format, infer_shape.size())) { - infer_shape = trans::PaddingShapeTo4d(infer_shape, GetInputReshapeType(node, input_idx)); + infer_shape = trans::PaddingShape(infer_shape, format, GetInputReshapeType(node, input_idx)); } return trans::TransShapeToDevice(infer_shape, format); } -std::vector AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNodePtr &node, size_t input_idx) { +std::string AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNodePtr &node, size_t input_idx) { MS_EXCEPTION_IF_NULL(node); if (input_idx > GetInputTensorNum(node)) { MS_LOG(EXCEPTION) << "The index:" << input_idx @@ -681,7 +681,7 @@ std::vector AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNodePtr &nod return build_info->GetInputReshapeType(input_idx); } -std::vector AnfRuntimeAlgorithm::GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx) { +std::string AnfRuntimeAlgorithm::GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx) { MS_EXCEPTION_IF_NULL(node); if (output_idx > GetOutputTensorNum(node)) { MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ " diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h index 3ea2c92f3c..01aac59ee6 100644 --- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h +++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h @@ -122,7 +122,7 @@ class AnfRuntimeAlgorithm { // get output format from prev node,input_index is the input index of current node related to prev node static std::string GetPrevNodeOutputFormat(const AnfNodePtr &node, size_t input_idx); // get reshape_type of from the output of input node. - static std::vector GetPrevNodeOutputReshapeType(const AnfNodePtr &node, size_t input_idx); + static std::string GetPrevNodeOutputReshapeType(const AnfNodePtr &node, size_t input_idx); // get output shapes inferred by ME from input nodes. static std::vector GetOutputInferShape(const AnfNodePtr &node, size_t output_idx); // get input shapes inferred by ME from input nodes. @@ -132,9 +132,9 @@ class AnfRuntimeAlgorithm { // get input shapes which will built and run in device static std::vector GetInputDeviceShape(const AnfNodePtr &node, size_t input_idx); // Get Input Padding Axis - static std::vector GetInputReshapeType(const AnfNodePtr &node, size_t output_idx); + static std::string GetInputReshapeType(const AnfNodePtr &node, size_t output_idx); // Get Output Padding Axis - static std::vector GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx); + static std::string GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx); // get output data type inferred by ME of anf node static TypeId GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx); // get output original data type from prev node,input_index is the input index of current node related to prev node diff --git a/mindspore/ccsrc/common/trans.cc b/mindspore/ccsrc/common/trans.cc index 9532b6fb52..497b29e278 100644 --- a/mindspore/ccsrc/common/trans.cc +++ b/mindspore/ccsrc/common/trans.cc @@ -21,6 +21,7 @@ #include "abstract/utils.h" #include "backend/session/anf_runtime_algorithm.h" #include "backend/kernel_compiler/kernel.h" +#include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h" #include "runtime/device/convert_tensor_utils.h" #include "utils/convert_utils.h" #include "utils/log_adapter.h" @@ -28,7 +29,7 @@ namespace mindspore { namespace trans { -enum kAxis : int { kN = 0, kC, kH, kW, kNchwDims, kNdhwc }; +enum kAxis : int { kN = 0, kC, kH, kW, kNchwDims, kNcdhw }; inline void SetData(size_t size, bool pad_zero, size_t src_idx, size_t dst_idx, const FormatArgs &args, void *result) { switch (size) { case 1: @@ -343,7 +344,7 @@ std::vector Nc1hwc04DeviceShape(const std::vector &shape) { } std::vector NcdhwDeviceShape(const std::vector &shape) { - if (shape.size() < kNdhwc) { + if (shape.size() < kNcdhw) { MS_LOG(EXCEPTION) << "Shape dims must be 5 when format is ndhwc."; } return shape; @@ -388,6 +389,20 @@ bool IsNeedPadding(const std::string &format, const size_t shape_size) { return false; } +std::vector PaddingShape(const std::vector &shape, const std::string &format, + const std::string &pad_index) { + std::vector host_shape; + if (k3DFormatSet.find(format) != k3DFormatSet.end()) { + if (shape.size() >= kNcdhw) { + return shape; + } + host_shape = trans::PaddingShapeTo5d(shape, pad_index); + } else { + host_shape = trans::PaddingShapeTo4d(shape, pad_index); + } + return host_shape; +} + ShapeVector GetRuntimePaddingShape(const AnfNodePtr &node, size_t index) { MS_EXCEPTION_IF_NULL(node); ShapeVector shape; @@ -409,14 +424,84 @@ ShapeVector GetRuntimePaddingShape(const AnfNodePtr &node, size_t index) { } else { host_shape = AnfAlgo::GetOutputInferShape(node, index); } - if (trans::IsNeedPadding(AnfAlgo::GetOutputFormat(node, index), host_shape.size())) { - host_shape = trans::PaddingShapeTo4d(host_shape, AnfAlgo::GetOutputReshapeType(node, index)); + auto format = AnfAlgo::GetOutputFormat(node, index); + if (trans::IsNeedPadding(format, host_shape.size())) { + host_shape = trans::PaddingShape(host_shape, format, AnfAlgo::GetOutputReshapeType(node, index)); } std::transform(host_shape.begin(), host_shape.end(), std::back_inserter(shape), SizeToLong); return shape; } -std::vector PaddingShapeTo4d(const std::vector &shape, const std::vector &padding_axis) { +void StringToAxisVector4D(const std::string &reshape_type_str, std::vector *reshape_type_vec) { + MS_EXCEPTION_IF_NULL(reshape_type_vec); + if (reshape_type_str.empty()) { + MS_LOG(DEBUG) << "Reshape type str is empty, no need padding."; + return; + } + for (const auto &c : reshape_type_str) { + switch (c) { + case 'N': + reshape_type_vec->push_back(N); + break; + case 'C': + reshape_type_vec->push_back(C); + break; + case 'H': + reshape_type_vec->push_back(H); + break; + case 'W': + reshape_type_vec->push_back(W); + break; + default: + MS_LOG(EXCEPTION) << "Unknown axis " << c << "in reshape type."; + } + } +} +void StringToAxisVector5D(const std::string &reshape_type_str, std::vector *reshape_type_vec) { + MS_EXCEPTION_IF_NULL(reshape_type_vec); + if (reshape_type_str.empty()) { + MS_LOG(DEBUG) << "Reshape type str is empty, no need padding."; + return; + } + for (const auto &c : reshape_type_str) { + switch (c) { + case 'N': + reshape_type_vec->push_back(N_ncdhw); + break; + case 'C': + reshape_type_vec->push_back(C_ncdhw); + break; + case 'D': + reshape_type_vec->push_back(D_ncdhw); + break; + case 'H': + reshape_type_vec->push_back(H_ncdhw); + break; + case 'W': + reshape_type_vec->push_back(W_ncdhw); + break; + default: + MS_LOG(EXCEPTION) << "Unknown axis " << c << "in reshape type."; + } + } +} + +std::vector PaddingShapeTo5d(const std::vector &shape, const std::string &padding_str) { + std::vector padding_axis; + StringToAxisVector5D(padding_str, &padding_axis); + if (padding_axis.empty() || shape.size() != padding_axis.size()) { + return PaddingShapeTo5dDefault(shape); + } + std::vector shape_5d(kNcdhw, 1); + for (size_t index = 0; index < padding_axis.size(); index++) { + shape_5d[padding_axis[index]] = shape[index]; + } + return shape_5d; +} + +std::vector PaddingShapeTo4d(const std::vector &shape, const std::string &padding_str) { + std::vector padding_axis; + StringToAxisVector4D(padding_str, &padding_axis); if (padding_axis.empty() || shape.size() != padding_axis.size()) { return PaddingShapeTo4dByDefault(shape); } @@ -427,6 +512,38 @@ std::vector PaddingShapeTo4d(const std::vector &shape, const std return shape_4d; } +std::vector PaddingShapeTo5dDefault(const std::vector &shape) { + if (shape.size() >= kNcdhw) { + return shape; + } + std::vector shape_5d(kNcdhw, 1); + switch (shape.size()) { + case 0: + return shape_5d; + case 1: + shape_5d[1] = shape[0]; + break; + case 2: + shape_5d[1] = shape[0]; + shape_5d[2] = shape[1]; + break; + case 3: + shape_5d[1] = shape[0]; + shape_5d[2] = shape[1]; + shape_5d[3] = shape[2]; + break; + case 4: + shape_5d[1] = shape[0]; + shape_5d[2] = shape[1]; + shape_5d[3] = shape[2]; + shape_5d[4] = shape[3]; + break; + default: + MS_LOG(EXCEPTION) << "Unexpected shape size = " << shape.size(); + } + return shape_5d; +} + std::vector TransShapeToDevice(const std::vector &shape, const std::string &format) { using DeviceShapeTransfer = std::function(const std::vector &)>; const std::map device_shape_map{{kOpFormat_NCHW, NchwDeviceShape}, @@ -475,10 +592,13 @@ std::vector TransShapeToDevice(const std::vector &shape, const s device_shape.push_back(kCubeSize); return device_shape; } - if (shape.size() != kNchwDims && shape.size() != 5) { + if (shape.size() != kNchwDims && k3DFormatSet.find(format) == k3DFormatSet.end()) { MS_LOG(WARNING) << "Get Device Shape using a shape size is less than 4 ,should be Padding shape by Default firstly"; temp_shape = PaddingShapeTo4dByDefault(shape); } + if (shape.size() != kNcdhw && k3DFormatSet.find(format) != k3DFormatSet.end()) { + temp_shape = PaddingShapeTo5dDefault(shape); + } auto iter = device_shape_map.find(format); if (iter == device_shape_map.end()) { MS_LOG(EXCEPTION) << "Unexpected format[" << format << "]"; diff --git a/mindspore/ccsrc/common/trans.h b/mindspore/ccsrc/common/trans.h index 9014f3c051..d5ea814f29 100644 --- a/mindspore/ccsrc/common/trans.h +++ b/mindspore/ccsrc/common/trans.h @@ -30,6 +30,13 @@ namespace mindspore { namespace trans { +enum Axis5D : int { + N_ncdhw = 0, + C_ncdhw, + D_ncdhw, + H_ncdhw, + W_ncdhw, +}; struct TypeIdArgs { const void *data; size_t host_shape_size; // Multiply each dimension elements. [a, b, c, d] => a*b*c*d @@ -50,7 +57,13 @@ struct FormatArgs { size_t CubeSizeByType(const TypeId data_type); -std::vector PaddingShapeTo4d(const std::vector &shape, const std::vector &padding_axis = {}); +std::vector PaddingShape(const std::vector &shape, const std::string &format, + const std::string &pad_index = {""}); +std::vector PaddingShapeTo4d(const std::vector &shape, const std::string &padding_axis = {""}); +std::vector PaddingShapeTo5d(const std::vector &shape, const std::string &padding_axis = {""}); +std::vector PaddingShapeTo5dDefault(const std::vector &shape); +void StringToAxisVector4D(const std::string &reshape_type_str, std::vector *reshape_type_vec); +void StringToAxisVector5D(const std::string &reshape_type_str, std::vector *reshape_type_vec); ShapeVector GetRuntimePaddingShape(const AnfNodePtr &node, size_t index); bool IsNeedPadding(const std::string &format, const size_t shape_size); std::vector TransShapeToDevice(const std::vector &shape, const std::string &format); diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc index 7d3bc87372..b2aab5f15f 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc @@ -475,7 +475,7 @@ std::vector AscendDeviceAddress::GetDeviceShape(std::vector *hos device_shape = trans::TransShapeToDevice(*host_shape, format_); } else { if (host_shape_.empty()) { - *host_shape = trans::PaddingShapeTo4d(*host_shape); + *host_shape = trans::PaddingShape(*host_shape, format_); } else { host_shape->clear(); (void)std::transform(host_shape_.begin(), host_shape_.end(), std::back_inserter(*host_shape), LongToSize); @@ -595,11 +595,10 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const ShapeVector &sh host_shape.emplace_back(1); } std::vector device_shape; - if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NCDHW || format_ == kOpFormat_NDC1HWC0 || - format_ == kOpFormat_FRACTAL_Z_3D) { + if (format_ == kOpFormat_FRAC_NZ) { device_shape = trans::TransShapeToDevice(host_shape, format_); } else { - host_shape = trans::PaddingShapeTo4d(host_shape); + host_shape = trans::PaddingShape(host_shape, format_); device_shape = trans::TransShapeToDevice(host_shape, format_); } if (type_id_ != type) { diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.cc b/mindspore/ccsrc/runtime/device/kernel_runtime.cc index 96765ccc77..ebff591dca 100644 --- a/mindspore/ccsrc/runtime/device/kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/kernel_runtime.cc @@ -68,7 +68,7 @@ size_t KernelRuntime::CountNodeDeviceMemorySize(const mindspore::AnfNodePtr &nod std::vector shape = AnfAlgo::GetOutputDeviceShape(node, output_index); auto format = AnfAlgo::GetOutputFormat(node, output_index); if (shape.empty() && format != kOpFormat_DEFAULT) { - shape = trans::PaddingShapeTo4d(shape, AnfAlgo::GetOutputReshapeType(node, output_index)); + shape = trans::PaddingShape(shape, format, AnfAlgo::GetOutputReshapeType(node, output_index)); shape = trans::TransShapeToDevice(shape, format); } // scalar's output shape is a empty vector diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index 37804d3848..3519733367 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -303,6 +303,7 @@ constexpr auto kAttrFactor = "factor"; constexpr auto kAttrIsRef = "isRef"; constexpr auto kAttrDataShape = "data_shape"; constexpr auto kAttrFormat = "format"; +constexpr auto kAttrReshapeType = "reshape_type"; constexpr auto kAttrAxis = "axis"; constexpr auto kAttrKeepDims = "keep_dims"; constexpr auto kAttrShapeGamma = "shape_gamma"; diff --git a/mindspore/core/ir/tensor.h b/mindspore/core/ir/tensor.h index 04b51e4f23..c8c1a60d79 100644 --- a/mindspore/core/ir/tensor.h +++ b/mindspore/core/ir/tensor.h @@ -285,8 +285,8 @@ class Tensor : public MetaTensor { DeviceSyncPtr device_address() const { return device_sync_; } void set_device_address(const DeviceSyncPtr &device_sync) { device_sync_ = device_sync; } - void set_padding_type(std::vector padding_type) { padding_type_ = padding_type; } - std::vector padding_type() const { return padding_type_; } + void set_padding_type(const std::string padding_type) { padding_type_ = padding_type; } + std::string padding_type() const { return padding_type_; } std::string id() const { return id_; } TypePtr cast_dtype() { return cast_dtype_; } @@ -366,7 +366,7 @@ class Tensor : public MetaTensor { bool cache_enable_{false}; std::shared_ptr cache_tensor_ptr_{nullptr}; std::shared_ptr hashmap_tensor_ptr_{nullptr}; - std::vector padding_type_; + std::string padding_type_{""}; TypePtr cast_dtype_{nullptr}; std::shared_ptr device_event_{nullptr}; }; diff --git a/tests/ut/cpp/pre_activate/ascend/format_type/insert_trans_op_test.cc b/tests/ut/cpp/pre_activate/ascend/format_type/insert_trans_op_test.cc index 35f0409e93..fa79554a22 100644 --- a/tests/ut/cpp/pre_activate/ascend/format_type/insert_trans_op_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/format_type/insert_trans_op_test.cc @@ -50,8 +50,8 @@ class TestHWInsertTransOp : public BackendCommon { KernelBuildInfoBuilder builder; builder.SetInputsFormat({format, format}); builder.SetInputsDeviceType({kFloat16->type_id(), kFloat16->type_id()}); - builder.SetInputsReshapeType({{},{}}); - builder.SetOutputsReshapeType({}); + builder.SetInputsReshapeType({"", ""}); + builder.SetOutputsReshapeType({""}); builder.SetOutputsFormat({format}); builder.SetOutputsDeviceType({kFloat16->type_id()}); add->set_kernel_info(std::make_shared()); @@ -72,8 +72,8 @@ class TestHWInsertTransOp : public BackendCommon { EXPECT_NE(ret->input(1)->cast()->input(1)->cast()->input(1), nullptr); auto max_pool = ret->input(1)->cast()->input(1)->cast()->input(1); KernelBuildInfoBuilder builder; - builder.SetInputsReshapeType({{}}); - builder.SetOutputsReshapeType({{},{}}); + builder.SetInputsReshapeType({""}); + builder.SetOutputsReshapeType({"", ""}); builder.SetInputsFormat({kOpFormat_DEFAULT}); builder.SetInputsDeviceType({kFloat16->type_id()}); builder.SetOutputsFormat({format, format}); @@ -92,8 +92,8 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect { ~MockInsertTransOpKernelSelectTrans4Dto5D() override = default; void SelectKernel(const CNodePtr &cnode) override { KernelBuildInfoBuilder builder; - builder.SetInputsReshapeType({{}}); - builder.SetOutputsReshapeType({{}}); + builder.SetInputsReshapeType({""}); + builder.SetOutputsReshapeType({""}); builder.SetInputsFormat({"NCHW"}); builder.SetInputsDeviceType({kFloat16->type_id()}); builder.SetOutputsFormat({"NC1HWC0"}); diff --git a/tests/ut/cpp/pre_activate/ascend/format_type/remove_internal_output_test.cc b/tests/ut/cpp/pre_activate/ascend/format_type/remove_internal_output_test.cc index 14cbb8eab0..7be62aa5c2 100644 --- a/tests/ut/cpp/pre_activate/ascend/format_type/remove_internal_output_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/format_type/remove_internal_output_test.cc @@ -53,8 +53,8 @@ class TestHWRemoveInternalOutput : public BackendCommon { KernelBuildInfoBuilder builder; builder.SetInputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT}); builder.SetInputsDeviceType({kFloat32->type_id(), kFloat32->type_id()}); - builder.SetInputsReshapeType({{}, {}}); - builder.SetOutputsReshapeType({{}}); + builder.SetInputsReshapeType({"", ""}); + builder.SetOutputsReshapeType({""}); builder.SetOutputsFormat({kOpFormat_NC1HWC0}); builder.SetOutputsDeviceType({kFloat16->type_id()}); add->set_kernel_info(std::make_shared()); @@ -80,8 +80,8 @@ class TestHWRemoveInternalOutput : public BackendCommon { kg->AddInternalOutput(tuple_getitem1, max_pool, 0, true); kg->AddInternalOutput(tuple_getitem2, max_pool, 1, true); KernelBuildInfoBuilder builder; - builder.SetInputsReshapeType({{}}); - builder.SetOutputsReshapeType({{}, {}}); + builder.SetInputsReshapeType({""}); + builder.SetOutputsReshapeType({"", ""}); builder.SetInputsFormat({kOpFormat_DEFAULT}); builder.SetInputsDeviceType({kFloat32->type_id()}); builder.SetOutputsFormat({kOpFormat_NC1HWC0, kOpFormat_NC1HWC0}); @@ -103,8 +103,8 @@ class MockRemoveInternalOutputTransOpKernelSelect : public KernelSelect { builder.SetInputsDeviceType({kFloat16->type_id()}); builder.SetOutputsFormat({kOpFormat_DEFAULT}); builder.SetOutputsDeviceType({kFloat32->type_id()}); - builder.SetInputsReshapeType({{}}); - builder.SetOutputsReshapeType({{}}); + builder.SetInputsReshapeType({""}); + builder.SetOutputsReshapeType({""}); AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get()); } }; diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fission/transdata_split_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fission/transdata_split_test.cc index 91c2a7c119..7e815449da 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fission/transdata_split_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fission/transdata_split_test.cc @@ -51,8 +51,8 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect { builder.SetInputsDeviceType({kFloat16->type_id()}); builder.SetOutputsFormat({"NC1HWC0"}); builder.SetOutputsDeviceType({kFloat16->type_id()}); - builder.SetInputsReshapeType({}); - builder.SetOutputsReshapeType({}); + builder.SetInputsReshapeType({""}); + builder.SetOutputsReshapeType({""}); AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get()); } else { KernelBuildInfoBuilder builder; @@ -60,8 +60,8 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect { builder.SetInputsDeviceType({kFloat16->type_id()}); builder.SetOutputsFormat({"NC1HWC0"}); builder.SetOutputsDeviceType({kFloat16->type_id()}); - builder.SetInputsReshapeType({}); - builder.SetOutputsReshapeType({}); + builder.SetInputsReshapeType({""}); + builder.SetOutputsReshapeType({""}); AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get()); } @@ -79,8 +79,8 @@ class MockTransdataSplitKernelSelect : public KernelSelect { builder.SetInputsDeviceType({kFloat16->type_id()}); builder.SetOutputsFormat({"NCHW"}); builder.SetOutputsDeviceType({kFloat16->type_id()}); - builder.SetInputsReshapeType({{}}); - builder.SetOutputsReshapeType({{}}); + builder.SetInputsReshapeType({""}); + builder.SetOutputsReshapeType({""}); AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get()); } else { KernelBuildInfoBuilder builder; @@ -88,8 +88,8 @@ class MockTransdataSplitKernelSelect : public KernelSelect { builder.SetInputsDeviceType({kFloat16->type_id()}); builder.SetOutputsFormat({"NCHW"}); builder.SetOutputsDeviceType({kFloat16->type_id()}); - builder.SetInputsReshapeType({{}}); - builder.SetOutputsReshapeType({{}}); + builder.SetInputsReshapeType({""}); + builder.SetOutputsReshapeType({""}); AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get()); } } @@ -125,8 +125,8 @@ TEST_F(TestHWTransdataSplit, test_transdata_split_fraz_nchw) { builder.SetKernelType(KernelType::TBE_KERNEL); builder.SetFusionType(kernel::FusionType::ELEMWISE); builder.SetProcessor(kernel::Processor::AICORE); - builder.SetInputsReshapeType({{}}); - builder.SetOutputsReshapeType({{}}); + builder.SetInputsReshapeType({""}); + builder.SetOutputsReshapeType({""}); auto kernel_info = std::make_shared(); kernel_info->set_select_kernel_build_info(builder.Build()); transpose->set_kernel_info(kernel_info); @@ -173,8 +173,8 @@ TEST_F(TestHWTransdataSplit, test_transdata_split_nchw_fraz) { builder.SetKernelType(KernelType::TBE_KERNEL); builder.SetFusionType(kernel::FusionType::ELEMWISE); builder.SetProcessor(kernel::Processor::AICORE); - builder.SetInputsReshapeType({{}}); - builder.SetOutputsReshapeType({{}}); + builder.SetInputsReshapeType({""}); + builder.SetOutputsReshapeType({""}); auto kernel_info = std::make_shared(); kernel_info->set_select_kernel_build_info(builder.Build()); transpose->set_kernel_info(kernel_info); diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_transdata_fusion_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_transdata_fusion_test.cc index 2929fdfb30..504499d20e 100644 --- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_transdata_fusion_test.cc +++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_transdata_fusion_test.cc @@ -58,8 +58,8 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect { builder.SetInputsDeviceType({kFloat16->type_id()}); builder.SetOutputsFormat({"NC1HWC0"}); builder.SetOutputsDeviceType({kFloat16->type_id()}); - builder.SetInputsReshapeType({}); - builder.SetOutputsReshapeType({}); + builder.SetInputsReshapeType({""}); + builder.SetOutputsReshapeType({""}); AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get()); } else { KernelBuildInfoBuilder builder; @@ -67,8 +67,8 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect { builder.SetInputsDeviceType({kFloat16->type_id()}); builder.SetOutputsFormat({"NC1HWC0"}); builder.SetOutputsDeviceType({kFloat16->type_id()}); - builder.SetInputsReshapeType({}); - builder.SetOutputsReshapeType({}); + builder.SetInputsReshapeType({""}); + builder.SetOutputsReshapeType({""}); AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get()); } } @@ -97,8 +97,8 @@ TEST_F(TestHWTransposeTransdataFusion, test_transpose_transdata_fusion) { EXPECT_NE(transpose, nullptr); KernelBuildInfoBuilder builder; - builder.SetInputsReshapeType({}); - builder.SetOutputsReshapeType({}); + builder.SetInputsReshapeType({""}); + builder.SetOutputsReshapeType({""}); builder.SetInputsFormat({"NCHW"}); builder.SetInputsDeviceType({kFloat16->type_id()}); builder.SetOutputsFormat({"NC1HWC0"}); diff --git a/tests/ut/cpp/pre_activate/pass/eliminate_redundant_op_test.cc b/tests/ut/cpp/pre_activate/pass/eliminate_redundant_op_test.cc index bad4365799..867be93334 100644 --- a/tests/ut/cpp/pre_activate/pass/eliminate_redundant_op_test.cc +++ b/tests/ut/cpp/pre_activate/pass/eliminate_redundant_op_test.cc @@ -56,8 +56,8 @@ class MockEliminate5To4And4To5KernelSelect : public KernelSelect { ~MockEliminate5To4And4To5KernelSelect() override = default; void SelectKernel(const CNodePtr &cnode) override { KernelBuildInfoBuilder builder; - builder.SetInputsReshapeType({{}}); - builder.SetOutputsReshapeType({{}}); + builder.SetInputsReshapeType({""}); + builder.SetOutputsReshapeType({""}); builder.SetInputsFormat({"NCHW"}); builder.SetInputsDeviceType({kFloat16->type_id()}); builder.SetOutputsFormat({"NC1HWC0"}); @@ -104,8 +104,8 @@ TEST_F(TestHWEliminateRedundantOp, test_eliminate_5to4_4to5) { builder.SetOutputsFormat({"NC1HWC0"}); builder.SetInputsDeviceType({kFloat16->type_id(), kFloat16->type_id()}); builder.SetOutputsDeviceType({kFloat16->type_id()}); - builder.SetInputsReshapeType({{}, {}}); - builder.SetOutputsReshapeType({{}}); + builder.SetInputsReshapeType({"", ""}); + builder.SetOutputsReshapeType({""}); sub->set_kernel_info(std::make_shared()); add->set_kernel_info(std::make_shared()); AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), sub.get()); @@ -171,8 +171,8 @@ TEST_F(TestHWEliminateRedundantOp, test_eliminate_cast) { builder.SetOutputsFormat({"NC1HWC0"}); builder.SetInputsDeviceType({kFloat16->type_id(), kFloat16->type_id()}); builder.SetOutputsDeviceType({kFloat16->type_id()}); - builder.SetInputsReshapeType({{}, {}}); - builder.SetOutputsReshapeType({{}, {}}); + builder.SetInputsReshapeType({"", ""}); + builder.SetOutputsReshapeType({"", ""}); sub->set_kernel_info(std::make_shared()); add->set_kernel_info(std::make_shared()); AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), sub.get()); @@ -248,8 +248,8 @@ TEST_F(TestHWEliminateRedundantOp, test_eliminate_cast_depend_cast) { builder.SetOutputsFormat({"NC1HWC0"}); builder.SetInputsDeviceType({kFloat16->type_id(), kFloat16->type_id()}); builder.SetOutputsDeviceType({kFloat16->type_id()}); - builder.SetInputsReshapeType({{}, {}}); - builder.SetOutputsReshapeType({{}}); + builder.SetInputsReshapeType({"", ""}); + builder.SetOutputsReshapeType({""}); sub->set_kernel_info(std::make_shared()); add->set_kernel_info(std::make_shared()); AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), sub.get());