diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.cc
index 1f844c72c4..a5d005d540 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.cc
@@ -95,6 +95,16 @@ constexpr auto kJSocVersion = "socVersion";
 constexpr auto kSOC_VERSION = "SOC_VERSION";
 constexpr auto kJIsDynamicShape = "is_dynamic_shape";
 
+bool IsNeedChangeDefaultFormat(const CNodePtr &cnode) {
+  MS_EXCEPTION_IF_NULL(cnode);
+  MS_LOG(INFO) << "Check if need change default format";
+  if (AnfAlgo::HasNodeAttr("io_format", cnode->cast<CNodePtr>())) {
+    auto attr = AnfAlgo::GetNodeAttr<std::string>(cnode, "io_format");
+    return attr == kOpFormat_NCDHW;
+  }
+  return false;
+}
+
 bool TbeKernelJsonCreator::GenTbeSingleKernelJson(const std::shared_ptr<AnfNode> &anf_node,
                                                   nlohmann::json *kernel_json) {
   MS_EXCEPTION_IF_NULL(anf_node);
@@ -161,10 +171,14 @@ void TbeKernelJsonCreator::GenValidInputDescJson(const std::shared_ptr<AnfNode> &anf_node, size_t real_input_index,
                                                  bool value, const std::shared_ptr<OpIOInfo> &input_ptr,
                                                  const string &op_input_name, size_t input_i,
                                                  std::vector<nlohmann::json> *input_list) {
+  auto def_format = kOpFormat_NCHW;
   auto dtype = GetDeviceInputType(anf_node, real_input_index);
   auto format = GetDeviceInputFormat(anf_node, real_input_index);
   auto shape = GetDeviceInputShape(anf_node, real_input_index);
   auto ori_shape = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, real_input_index);
+  if (anf_node->isa<CNode>() && IsNeedChangeDefaultFormat(anf_node->cast<CNodePtr>())) {
+    def_format = kOpFormat_NCDHW;
+  }
   if (ori_shape.empty()) {
     ori_shape.emplace_back(1);
   }
@@ -172,7 +186,7 @@ void TbeKernelJsonCreator::GenValidInputDescJson(const std::shared_ptr<AnfNode> &anf_node, size_t real_input_index,
   input_desc_json[kJDtype] = dtype;
   input_desc_json[kJName] = op_input_name + std::to_string(input_i);
   input_desc_json[kJOriShape] = ori_shape;
-  input_desc_json[kJOriFormat] = kOpFormat_NCHW;
+  input_desc_json[kJOriFormat] = def_format;
   input_desc_json[kJShape] = shape;
   input_desc_json[kJFormat] = format;
   input_desc_json[kJValid] = value;
@@ -379,6 +393,10 @@ void TbeKernelJsonCreator::GenOutputList(const std::shared_ptr<AnfNode> &anf_node,
                                          std::vector<nlohmann::json> *output_list) {
   MS_EXCEPTION_IF_NULL(output_idx);
   MS_EXCEPTION_IF_NULL(output_list);
+  auto def_format = kOpFormat_NCHW;
+  if (anf_node->isa<CNode>() && IsNeedChangeDefaultFormat(anf_node->cast<CNodePtr>())) {
+    def_format = kOpFormat_NCDHW;
+  }
   for (size_t i = 0; i < output_obj_num; i++) {
     auto dtype = GetDeviceOutputType(anf_node, *output_idx);
     auto format = GetDeviceOutputFormat(anf_node, *output_idx);
@@ -397,7 +415,7 @@ void TbeKernelJsonCreator::GenOutputList(const std::shared_ptr<AnfNode> &anf_node,
     output_obj[kJShape] = shape;
     output_obj[kJFormat] = format;
     output_obj[kJOriShape] = ori_shape;
-    output_obj[kJOriFormat] = kOpFormat_NCHW;
+    output_obj[kJOriFormat] = def_format;
     output_obj[kJName] = output_ptr->name();
     output_obj[kJValid] = true;
     output_obj[kJParamType] = output_ptr->param_type();
@@ -580,6 +598,9 @@ std::string TbeKernelJsonCreator::GetDeviceInputFormat(const AnfNodePtr &anf_node, size_t real_input_index) {
       format = kOpFormat_NCHW;
     }
   }
+  if (anf_node->isa<CNode>() && IsNeedChangeDefaultFormat(anf_node->cast<CNodePtr>())) {
+    format = kOpFormat_NCDHW;
+  }
   return format;
 }
 
@@ -619,6 +640,9 @@ std::string TbeKernelJsonCreator::GetDeviceOutputFormat(const AnfNodePtr &anf_node, size_t output_index) {
       format = kOpFormat_NCHW;
     }
   }
+  if (anf_node->isa<CNode>() && IsNeedChangeDefaultFormat(anf_node->cast<CNodePtr>())) {
+    format = kOpFormat_NCDHW;
+  }
   return format;
 }
 
@@ -818,6 +842,10 @@ void TbeKernelBuild::GenSuffixDescJson(nlohmann::json *output_desc) {
 void TbeKernelBuild::GenDescJson(const std::shared_ptr<mindspore::AnfNode> &anf_node, size_t node_out_idx,
                                  size_t desc_output_idx, nlohmann::json *output_desc,
                                  FusionDataType fusion_data_type) {
   GenPreDescJson(output_desc);
+  auto def_format = kOpFormat_NCHW;
+  if (anf_node->isa<CNode>() && IsNeedChangeDefaultFormat(anf_node->cast<CNodePtr>())) {
+    def_format = kOpFormat_NCDHW;
+  }
   // data_type
   auto type_id = AnfAlgo::GetOutputDeviceDataType(anf_node, node_out_idx);
   (*output_desc)[kJDataType] = tbe::TypeIdToString(type_id);
@@ -828,7 +856,7 @@ void TbeKernelBuild::GenDescJson(const std::shared_ptr<mindspore::AnfNode> &anf_node,
   }
   (*output_desc)[kJName] = output_desc_name;
   // ori_format
-  (*output_desc)[kJOriFormat] = kOpFormat_NCHW;
+  (*output_desc)[kJOriFormat] = def_format;
   // ori_shape
   auto ori_shape = AnfAlgo::GetOutputInferShape(anf_node, node_out_idx);
   if (ori_shape.empty()) {
diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.cc
index 7c43d3a4cc..3c34bd5795 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.cc
@@ -248,13 +248,57 @@ bool TbeKernelBroadCastSelecter::IsBroadCastSupportFracNZ(SupportFormat *support_format) const {
 
 bool TbeKernelBroadCastSelecter::IsBroadCastSupportNDC1HWC0(SupportFormat *support_format) const {
   MS_EXCEPTION_IF_NULL(support_format);
-  return false;
+  if (IsSameShape()) {
+    if (!HasScalarInput()) {
+      AssignSupportFormat(kOpFormat_NDC1HWC0, support_format);
+      return true;
+    }
+    return false;
+  }
+  SupportFormatItem input_support_format;
+  SupportFormatItem output_support_format;
+  if (HasScalarInput()) {
+    for (const auto &shape : input_shapes_) {
+      if (IsScalarShape(shape)) {
+        input_support_format.emplace_back(kOpFormat_NCDHW);
+      } else if (!Is5DShape(shape)) {
+        return false;
+      } else if (shape[kChannelC] % kAlignmented16 != 0) {
+        return false;
+      } else {
+        input_support_format.emplace_back(kOpFormat_NDC1HWC0);
+      }
+    }
+  } else {
+    for (const auto &shape : input_shapes_) {
+      if (!Is5DShape(shape)) {
+        return false;
+      }
+    }
+    auto shape_tmp = input_shapes_[0];
+    auto broadcast_c_axis = std::any_of(
+      input_shapes_.begin(), input_shapes_.end(),
+      [&shape_tmp](const std::vector<size_t> &elem) { return shape_tmp.at(kChannelC) != elem.at(kChannelC); });
+    if (broadcast_c_axis) {
+      MS_LOG(INFO) << "This node broadcasts the c channel.";
+      return false;
+    }
+    input_support_format.assign(input_num_, kOpFormat_NDC1HWC0);
+  }
+  GenOutputSupportFormat(kOpFormat_NDC1HWC0, &output_support_format);
+  support_format->input_format.emplace_back(input_support_format);
+  support_format->output_format.emplace_back(output_support_format);
+  return true;
 }
 
 bool TbeKernelBroadCastSelecter::Is4DShape(const std::vector<size_t> &shape) const {
   return shape.size() == kShape4dDims;
 }
 
+bool TbeKernelBroadCastSelecter::Is5DShape(const std::vector<size_t> &shape) const {
+  return shape.size() == kShape5dDims;
+}
+
 bool TbeKernelBroadCastSelecter::IsSameShape() const {
   auto shape = input_shapes_.begin();
   for (const auto &item : input_shapes_) {
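For orientation, the eligibility rule that the new IsBroadCastSupportNDC1HWC0 encodes can be restated as a standalone predicate. The sketch below is illustrative only, not MindSpore code: it assumes NCDHW axis order (so kChannelC == 1), a 16-element C0 block (kAlignmented16), and plain shape vectors in place of the selecter's member state.

#include <cstddef>
#include <vector>

using Shape = std::vector<size_t>;

constexpr size_t kChannelC = 1;        // C position in an NCDHW shape (assumption)
constexpr size_t kAlignmented16 = 16;  // C0 block size on device (assumption)

// Simplified mirror of IsBroadCastSupportNDC1HWC0: identical non-scalar inputs
// are always eligible; scalars may mix with 5D inputs whose C fills whole C0
// blocks; otherwise every input must be 5D with the same C, because
// broadcasting along C is impossible once C is split across C1/C0.
bool SupportsNdc1hwc0(const std::vector<Shape> &inputs) {
  auto is_scalar = [](const Shape &s) { return s.size() <= 1; };
  bool has_scalar = false;
  for (const auto &s : inputs) has_scalar = has_scalar || is_scalar(s);
  bool all_same = true;
  for (const auto &s : inputs) all_same = all_same && (s == inputs.front());
  if (all_same) return !has_scalar;
  for (const auto &s : inputs) {
    if (is_scalar(s)) continue;  // scalar inputs keep plain NCDHW
    if (s.size() != 5) return false;
    if (has_scalar && s[kChannelC] % kAlignmented16 != 0) return false;
    if (!has_scalar && s[kChannelC] != inputs.front()[kChannelC]) return false;
  }
  return true;
}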
diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.h b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.h
index fb5f1f554b..6dae9e1a93 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_broadcast_selecter.h
@@ -40,6 +40,7 @@ class TbeKernelBroadCastSelecter {
   bool IsSameShape() const;
   void PadScalarShape(std::vector<size_t> *shape) const;
   bool Is4DShape(const std::vector<size_t> &shape) const;
+  bool Is5DShape(const std::vector<size_t> &shape) const;
   bool IsScalarShape(const std::vector<size_t> &shape) const;
   bool HasScalarInput() const;
   void GenOutputSupportFormat(const std::string &support_format, SupportFormatItem *output_support_item) const;
diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.cc
index 61aa9dfb91..2914b36bfc 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.cc
@@ -72,8 +72,18 @@ bool TbeKernelReduceSelecter::IsReduceSupport5HD(SupportFormat *support_format) const {
 
 bool TbeKernelReduceSelecter::IsReduceSupportNDC1HWC0(SupportFormat *support_format) const {
   MS_EXCEPTION_IF_NULL(support_format);
-  // like to 5HD
-  return false;
+  if (!Is5DShape(input_shape_)) {
+    return false;
+  }
+  if (!keep_dims_ || axis_.empty()) {
+    return false;
+  }
+  auto reduce_c_axis =
+    std::any_of(axis_.begin(), axis_.end(), [](const size_t &elem) { return (elem == kChannelC); });
+  if (reduce_c_axis) {
+    return false;
+  }
+  AssignSupportFormat(kOpFormat_NDC1HWC0, support_format);
+  return true;
 }
 
 bool TbeKernelReduceSelecter::IsReduceSupportFracZ(SupportFormat *support_format) const {
@@ -142,6 +152,8 @@ void TbeKernelReduceSelecter::AssignSupportFormat(const std::string &support_format_str,
 bool TbeKernelReduceSelecter::Is4DShape(const std::vector<size_t> &shape) const { return shape.size() == kShape4dDims; }
 
+bool TbeKernelReduceSelecter::Is5DShape(const std::vector<size_t> &shape) const { return shape.size() == kShape5dDims; }
+
 void TbeKernelReduceSelecter::PadScalarShape(std::vector<size_t> *shape) const {
   MS_EXCEPTION_IF_NULL(shape);
   if (shape->empty()) {
diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h
index b68bfd60ca..c16e5c67f2 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h
@@ -39,6 +39,7 @@ class TbeKernelReduceSelecter {
   void GetReduceAttrKeepDim();
   void AssignSupportFormat(const std::string &support_format_str, SupportFormat *support_format) const;
   bool Is4DShape(const std::vector<size_t> &shape) const;
+  bool Is5DShape(const std::vector<size_t> &shape) const;
   void PadScalarShape(std::vector<size_t> *shape) const;
   CNodePtr cnode_ptr_;
   std::vector<size_t> input_shape_{};
diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc
index 93207f0649..bd4d8334d5 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc
@@ -187,6 +187,9 @@ void TbeKernelSelect::GetBroadcastPatternKernelInfo(const OpInfo &op_info) {
   if (!broadcast_selecter.IsBroadCastSupportFracNZ(&support_format)) {
     MS_LOG(INFO) << "Node(" << node_name_ << ") does not support FracNZ.";
   }
+  if (!broadcast_selecter.IsBroadCastSupportNDC1HWC0(&support_format)) {
+    MS_LOG(INFO) << "Node(" << node_name_ << ") does not support NDC1HWC0.";
+  }
   PrintSupportedFormat(support_format);
   OpInfo op_info_new;
   CreateNewOpInfo(op_info, support_format, &op_info_new);
@@ -281,10 +284,8 @@ bool TbeKernelSelect::IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &format) {
     return true;
   }
   // not support format:
-  // 1 NDHWC with shape size != 5
-  // 3 !NDHWC with shape size > 4
-  if ((format == kOpFormat_NDHWC && shape.size() != kShape5dDims) ||
-      (format != kOpFormat_NDHWC && shape.size() > kShape4dDims)) {
+  // 1 NCDHW with shape size != 5
+  if (format == kOpFormat_NCDHW && shape.size() != kShape5dDims) {
     MS_LOG(INFO) << "Warning: Shape format check failed, format: " << format << ", size: " << shape.size();
     return false;
   }
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc
index e9e85024c4..9c227abf74 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc
@@ -32,7 +32,7 @@ namespace mindspore {
 namespace opt {
 using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
 namespace {
-const std::set<std::string> kCommonFormatSet = {kOpFormat_DEFAULT, kOpFormat_ND, kOpFormat_NCHW};
+const std::set<std::string> kCommonFormatSet = {kOpFormat_DEFAULT, kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NCDHW};
 AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node,
                              const KernelSelectPtr &kernel_select, const std::vector<size_t> &dst_shape) {
   std::vector<AnfNodePtr> trans_inputs;
@@ -70,9 +70,17 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
   CNodePtr trans_data = nullptr;
   MS_EXCEPTION_IF_NULL(node);
   // Init
+  std::string default_format = kOpFormat_DEFAULT;
+
+  if (node->isa<CNode>() && AnfAlgo::HasNodeAttr("io_format", node->cast<CNodePtr>())) {
+    auto attr = AnfAlgo::GetNodeAttr<std::string>(node, "io_format");
+    if (attr == kOpFormat_NCDHW) {
+      default_format = kOpFormat_NCDHW;
+    }
+  }
   AnfNodePtr input_node = is_insert_input ? AnfAlgo::GetInputNode(node->cast<CNodePtr>(), insert_index) : node;
-  std::string input_format = is_insert_input ? kOpFormat_DEFAULT : AnfAlgo::GetOutputFormat(node, insert_index);
-  std::string dst_format = is_insert_input ? AnfAlgo::GetInputFormat(node, insert_index) : kOpFormat_DEFAULT;
+  std::string input_format = is_insert_input ? default_format : AnfAlgo::GetOutputFormat(node, insert_index);
+  std::string dst_format = is_insert_input ? AnfAlgo::GetInputFormat(node, insert_index) : default_format;
   std::vector<kernel::Axis> padding_axis =
     is_insert_input ? AnfAlgo::GetInputReshapeType(node, insert_index) : AnfAlgo::GetOutputReshapeType(node, insert_index);
   auto input_node_out_shape = is_insert_input ? AnfAlgo::GetPrevNodeOutputInferShape(node, insert_index)
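Note that the "io_format" attribute is read the same way in three places in this patch: kernel JSON generation, the trans-op insertion above, and kernel graph construction below. A node carrying io_format == "NCDHW" flips every spot that previously hard-coded the 4D default format. A minimal sketch of that lookup, using a hypothetical attribute map in place of MindSpore's AnfAlgo accessors:

#include <map>
#include <string>

// Hypothetical stand-in for a node's attribute dictionary.
using AttrMap = std::map<std::string, std::string>;

// Mirrors the pattern in AddTransOpNodeToGraph: fall back to the 4D default
// unless the node explicitly requests NCDHW via its "io_format" attribute.
std::string DefaultFormatFor(const AttrMap &attrs) {
  auto it = attrs.find("io_format");
  if (it != attrs.end() && it->second == "NCDHW") {
    return "NCDHW";  // kOpFormat_NCDHW
  }
  return "DefaultFormat";  // kOpFormat_DEFAULT
}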
diff --git a/mindspore/ccsrc/backend/session/kernel_graph.cc b/mindspore/ccsrc/backend/session/kernel_graph.cc
index 1306d49931..03bdfaafcc 100644
--- a/mindspore/ccsrc/backend/session/kernel_graph.cc
+++ b/mindspore/ccsrc/backend/session/kernel_graph.cc
@@ -369,6 +369,26 @@ void KernelGraph::CheckLoop() {
   }
 }
 
+void ReSetParameterValueNodeFormatAndType(const AnfNodePtr &node, const std::string &format) {
+  MS_EXCEPTION_IF_NULL(node);
+  auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
+  MS_EXCEPTION_IF_NULL(kernel_build_info_builder);
+  kernel_build_info_builder->SetOutputsFormat({format});
+  kernel_build_info_builder->SetOutputsDeviceType({AnfAlgo::GetOutputInferDataType(node, 0)});
+  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), node.get());
+}
+
+void KernelGraph::ResetInFormat(const AnfNodePtr &node, const std::string &format) const {
+  MS_EXCEPTION_IF_NULL(node);
+  for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(node); i++) {
+    auto in_node = AnfAlgo::GetInputNode(node->cast<CNodePtr>(), i);
+    MS_EXCEPTION_IF_NULL(in_node);
+    if (in_node->isa<Parameter>() || in_node->isa<ValueNode>()) {
+      ReSetParameterValueNodeFormatAndType(in_node, format);
+    }
+  }
+}
+
 CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
   auto cnode = FuncGraph::NewCNode(inputs);
   MS_EXCEPTION_IF_NULL(cnode);
@@ -378,6 +398,12 @@ CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
     AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(false), cnode);
   }
   SetKernelInfoForNode(cnode);
+  if (AnfAlgo::HasNodeAttr("io_format", cnode)) {
+    auto attr = AnfAlgo::GetNodeAttr<std::string>(cnode, "io_format");
+    if (attr == kOpFormat_NCDHW) {
+      ResetInFormat(cnode, kOpFormat_NCDHW);
+    }
+  }
   AnfAlgo::SetGraphId(graph_id_, cnode.get());
   return cnode;
 }
diff --git a/mindspore/ccsrc/backend/session/kernel_graph.h b/mindspore/ccsrc/backend/session/kernel_graph.h
index 50334e58c2..cfab2d76b1 100644
--- a/mindspore/ccsrc/backend/session/kernel_graph.h
+++ b/mindspore/ccsrc/backend/session/kernel_graph.h
@@ -273,6 +273,7 @@ class KernelGraph : public FuncGraph {
   // remove value node form graph
   bool RemoveValueNodeFromGraph(const ValueNodePtr &value_node);
   void SetKernelInfoForNode(const AnfNodePtr &node) const;
+  void ResetInFormat(const AnfNodePtr &node, const std::string &format) const;
   AnfNodePtr MakeValueNode(const AnfNodePtr &node);
   void VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue,
                             std::unordered_set<AnfNodePtr> *visited_nodes);
diff --git a/mindspore/ccsrc/common/trans.cc b/mindspore/ccsrc/common/trans.cc
index b8729190e6..f274009537 100644
--- a/mindspore/ccsrc/common/trans.cc
+++ b/mindspore/ccsrc/common/trans.cc
@@ -266,6 +266,41 @@ std::vector<size_t> Nc1hwc0DeviceShape(const std::vector<size_t> &shape) {
   return device_shape;
 }
 
+std::vector<size_t> Ndc1hwc0DeviceShape(const std::vector<size_t> &shape) {
+  // NCDHW
+  if (shape.size() != 5) {
+    MS_LOG(EXCEPTION) << "Check dims failed, expect shape dim 5, but got shape dim : " << shape.size();
+  }
+  std::vector<size_t> device_shape;
+  const size_t C1 = (shape[1] + kCubeSize - 1) / kCubeSize;
+  const size_t C0 = kCubeSize;
+  device_shape.push_back(shape[0]);
+  device_shape.push_back(shape[2]);
+  device_shape.push_back(C1);
+  device_shape.push_back(shape[3]);
+  device_shape.push_back(shape[4]);
+  device_shape.push_back(C0);
+  return device_shape;
+}
+
+std::vector<size_t> Fracz3DDeviceShape(const std::vector<size_t> &shape) {
+  // NCDHW -> Frac_Z_3D
+  if (shape.size() != 5) {
+    MS_LOG(EXCEPTION) << "Check dims failed, expect shape dim 5, but got shape dim : " << shape.size();
+  }
+  std::vector<size_t> device_shape;
+  const size_t C1 = (shape[1] + kCubeSize - 1) / kCubeSize;
+  const size_t N1 = (shape[0] + kCubeSize - 1) / kCubeSize;
+  device_shape.push_back(shape[2]);
+  device_shape.push_back(C1);
+  device_shape.push_back(shape[3]);
+  device_shape.push_back(shape[4]);
+  device_shape.push_back(N1);
+  device_shape.push_back(kCubeSize);
+  device_shape.push_back(kCubeSize);
+  return device_shape;
+}
+
 std::vector<size_t> C1hwncoc0DeviceShape(const std::vector<size_t> &shape) {
   if (!CheckDims(shape)) {
     MS_LOG(EXCEPTION) << "Check dims failed.";
@@ -310,7 +345,7 @@ std::vector<size_t> Nc1hwc04DeviceShape(const std::vector<size_t> &shape) {
   return device_shape;
 }
 
-std::vector<size_t> NdhwcDeviceShape(const std::vector<size_t> &shape) {
+std::vector<size_t> NcdhwDeviceShape(const std::vector<size_t> &shape) {
   if (shape.size() < kNdhwc) {
     MS_LOG(EXCEPTION) << "Shape dims must be 5 when format is ndhwc.";
   }
@@ -405,7 +440,9 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format) {
     {kOpFormat_C1HWNCoC0, C1hwncoc0DeviceShape},
     {kOpFormat_FRACTAL_Z_C04, FracZc04DeviceShape},
     {kOpFormat_NC1HWC0_C04, Nc1hwc04DeviceShape},
-    {kOpFormat_NDHWC, NdhwcDeviceShape}};
+    {kOpFormat_NCDHW, NcdhwDeviceShape},
+    {kOpFormat_NDC1HWC0, Ndc1hwc0DeviceShape},
+    {kOpFormat_FRACTAL_Z_3D, Fracz3DDeviceShape}};
 
   if (format == kOpFormat_ND || format == kOpFormat_DEFAULT) {
     return shape;
@@ -441,7 +478,7 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const std::string &format) {
     device_shape.push_back(kCubeSize);
     return device_shape;
   }
-  if (shape.size() != kNchwDims) {
+  if (shape.size() != kNchwDims && shape.size() != 5) {
     MS_LOG(WARNING) << "Get Device Shape using a shape size is less than 4 ,should be Padding shape by Default firstly";
     temp_shape = PaddingShapeTo4dByDefault(shape);
   }
@@ -496,7 +533,9 @@ bool TransFormat(const FormatArgs &args, void *result) {
   const std::map<std::string, FormatTransfer> format_trans_map{
     {kOpFormat_FRAC_Z, NchwToFracZ},           {kOpFormat_FRAC_NZ, NchwToFracNz},
     {kOpFormat_NC1HWC0, NchwToNc1hwc0},        {kOpFormat_C1HWNCoC0, NchwToC1hwncoc0},
-    {kOpFormat_FRACTAL_Z_C04, NchwToFracZc04}, {kOpFormat_NC1HWC0_C04, NchwToNc1hwc04}};
+    {kOpFormat_FRACTAL_Z_C04, NchwToFracZc04}, {kOpFormat_NC1HWC0_C04, NchwToNc1hwc04},
+    {kOpFormat_NDC1HWC0, NcdhwToNdc1hwc0}};
+
   MS_LOG(DEBUG) << "Start trans format.";
   if (abstract::TypeIdSize(args.src_data_type) < 1) {
     MS_LOG(ERROR) << "Invalid datatype..";
@@ -514,11 +553,11 @@ bool TransFormatFromDeviceToHost(const FormatArgs &args, void *result) {
   using FormatTransfer = std::function<bool(const FormatArgs &, void *)>;
-  const std::map<std::string, FormatTransfer> format_trans_map{{kOpFormat_FRAC_Z, FracZToNchw},
-                                                               {kOpFormat_FRAC_NZ, FracNzToNchw},
-                                                               {kOpFormat_NC1HWC0, Nc1hwc0ToNchw},
-                                                               {kOpFormat_C1HWNCoC0, C1hwncoc0ToNchw},
-                                                               {kOpFormat_NC1HWC0_C04, Nc1hwc04ToNchw}};
+  const std::map<std::string, FormatTransfer> format_trans_map{
+    {kOpFormat_FRAC_Z, FracZToNchw},         {kOpFormat_FRAC_NZ, FracNzToNchw},
+    {kOpFormat_NC1HWC0, Nc1hwc0ToNchw},      {kOpFormat_C1HWNCoC0, C1hwncoc0ToNchw},
+    {kOpFormat_NC1HWC0_C04, Nc1hwc04ToNchw}, {kOpFormat_NDC1HWC0, Ndc1hwc0ToNcdhw}};
+
   MS_LOG(DEBUG) << "Start trans format.";
   if (abstract::TypeIdSize(args.src_data_type) < 1) {
     MS_LOG(ERROR) << "Invalid datatype..";
@@ -1106,5 +1145,119 @@ bool C1hwncoc0ToNchw(const FormatArgs &args, void *result) {
   }
   return true;
 }
+
+bool Ndc1hwc0ToNcdhw(const FormatArgs &args, void *result) {
+  MS_LOG(DEBUG) << "Trans from ndc1hwc0 to ncdhw";
+  MS_EXCEPTION_IF_NULL(result);
+
+  if (args.host_shape.size() != 5) {
+    MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
+    return false;
+  }
+  auto size = abstract::TypeIdSize(args.src_data_type);
+  if (size < 1) {
+    MS_LOG(ERROR) << "Illegal dtype.";
+    return false;
+  }
+  auto total_size = abstract::ShapeSize(args.device_shape) * size;
+  if (total_size != args.device_size) {
+    MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
+    return false;
+  }
+  auto n = args.host_shape[0];
+  auto c = args.host_shape[1];
+  auto d = args.host_shape[2];
+  auto h = args.host_shape[3];
+  auto w = args.host_shape[4];
+  auto c1 = args.device_shape[2];
+  auto c0 = args.device_shape[5];
+  const size_t cdhw = c * d * h * w;
+  const size_t dhw = d * h * w;
+  const size_t hw = h * w;
+  const size_t dc1hwc0 = d * c1 * h * w * c0;
+  const size_t c1hwc0 = c1 * h * w * c0;
+  const size_t hwc0 = h * w * c0;
+  const size_t wc0 = w * c0;
+
+  for (size_t n_i = 0; n_i < n; n_i++) {
+    size_t n_head = n_i * cdhw;
+    for (size_t c_i = 0; c_i < c; c_i++) {
+      size_t c_head = n_head + c_i * dhw;
+      for (size_t d_i = 0; d_i < d; d_i++) {
+        size_t d_head = c_head + d_i * hw;
+        for (size_t h_i = 0; h_i < h; h_i++) {
+          size_t h_head = d_head + h_i * w;
+          for (size_t w_i = 0; w_i < w; w_i++) {
+            size_t dst_i = h_head + w_i;
+            size_t c1_i = c_i / c0;
+            size_t c0_i = c_i % c0;
+            auto src_idx = n_i * dc1hwc0 + d_i * c1hwc0 + c1_i * hwc0 + h_i * wc0 + w_i * c0 + c0_i;
+            SetData(size, false, src_idx, dst_i, args, result);
+          }
+        }
+      }
+    }
+  }
+  return true;
+}
+
+bool NcdhwToNdc1hwc0(const FormatArgs &args, void *result) {
+  MS_LOG(DEBUG) << "Trans from ncdhw to ndc1hwc0";
+  MS_EXCEPTION_IF_NULL(result);
+
+  if (args.host_shape.size() != 5) {
+    MS_LOG(ERROR) << "Illegal host shape dim, expect dim: 5, but got " << args.host_shape.size();
+    return false;
+  }
+  auto size = abstract::TypeIdSize(args.src_data_type);
+  if (size < 1) {
+    MS_LOG(ERROR) << "Illegal dtype.";
+    return false;
+  }
+  auto total_size = abstract::ShapeSize(args.device_shape) * size;
+  if (total_size != args.device_size) {
+    MS_LOG(ERROR) << "Illegal total data size, total_size:" << total_size << ", device_size:" << args.device_size;
+    return false;
+  }
+
+  auto n = args.host_shape[0];
+  auto c = args.host_shape[1];
+  auto d = args.host_shape[2];
+  auto h = args.host_shape[3];
+  auto w = args.host_shape[4];
+  auto c0 = kCubeSize;
+  auto c1 = DivCeil(c, c0);
+  const size_t cdhw = c * d * h * w;
+  const size_t dhw = d * h * w;
+  const size_t hw = h * w;
+  const size_t dc1hwc0 = d * c1 * h * w * c0;
+  const size_t c1hwc0 = c1 * h * w * c0;
+  const size_t hwc0 = h * w * c0;
+  const size_t wc0 = w * c0;
+
+  for (size_t n_i = 0; n_i < n; n_i++) {
+    size_t n_head = n_i * dc1hwc0;
+    for (size_t d_i = 0; d_i < d; d_i++) {
+      size_t d_head = n_head + d_i * c1hwc0;
+      for (size_t c1_i = 0; c1_i < c1; c1_i++) {
+        size_t c1_head = d_head + c1_i * hwc0;
+        for (size_t h_i = 0; h_i < h; h_i++) {
+          size_t h_head = c1_head + h_i * wc0;
+          for (size_t w_i = 0; w_i < w; w_i++) {
+            size_t w_head = h_head + w_i * c0;
+            for (size_t c0_i = 0; c0_i < c0; c0_i++) {
+              size_t dst_i = c0_i + w_head;
+              size_t c_i = c0_i + c1_i * c0;
+              size_t src_i = n_i * cdhw + c_i * dhw + d_i * hw + h_i * w + w_i;
+              auto pad_zero = c_i >= c;
+              SetData(size, pad_zero, src_i, dst_i, args, result);
+            }
+          }
+        }
+      }
+    }
+  }
+  return true;
+}
 }  // namespace trans
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/common/trans.h b/mindspore/ccsrc/common/trans.h
index cad059eaa2..3275d9e364 100644
--- a/mindspore/ccsrc/common/trans.h
+++ b/mindspore/ccsrc/common/trans.h
@@ -66,6 +66,8 @@ bool NchwToNc1hwc0(const FormatArgs &args, void *result);
 bool NchwToFracZc04(const FormatArgs &args, void *result);
 bool NchwToNc1hwc04(const FormatArgs &args, void *result);
 bool NchwToC1hwncoc0(const FormatArgs &args, void *result);
+bool NcdhwToNdc1hwc0(const FormatArgs &args, void *result);
+
 // device to host
 bool ToNchw(const FormatArgs &args, void *result);
 bool FracZToNchw(const FormatArgs &args, void *result);
@@ -73,6 +75,7 @@ bool FracNzToNchw(const FormatArgs &args, void *result);
 bool Nc1hwc0ToNchw(const FormatArgs &args, void *result);
 bool Nc1hwc04ToNchw(const FormatArgs &args, void *result);
 bool C1hwncoc0ToNchw(const FormatArgs &args, void *result);
+bool Ndc1hwc0ToNcdhw(const FormatArgs &args, void *result);
 }  // namespace trans
 }  // namespace mindspore
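To make the new 5D layout concrete: NDC1HWC0 splits C into C1 blocks of C0 (= kCubeSize = 16) channels, orders the axes N, D, C1, H, W, C0, and zero-pads the last block when C is not a multiple of 16 (the pad_zero flag in NcdhwToNdc1hwc0). The self-contained sketch below reproduces the shape and offset arithmetic of Ndc1hwc0DeviceShape and NcdhwToNdc1hwc0 for illustration; it is not the MindSpore API.

#include <cstddef>
#include <cstdio>
#include <vector>

constexpr size_t kCube = 16;  // C0, MindSpore's kCubeSize

// NCDHW host shape -> NDC1HWC0 device shape, as in Ndc1hwc0DeviceShape.
std::vector<size_t> ToNdc1hwc0Shape(const std::vector<size_t> &s) {
  const size_t c1 = (s[1] + kCube - 1) / kCube;  // ceil(C / C0)
  return {s[0], s[2], c1, s[3], s[4], kCube};    // N, D, C1, H, W, C0
}

// Flat device offset of host element (n, c, d, h, w); this is the same
// polynomial as src_idx in Ndc1hwc0ToNcdhw, written in Horner form.
size_t DeviceOffset(const std::vector<size_t> &s, size_t n, size_t c, size_t d, size_t h, size_t w) {
  const size_t c1 = (s[1] + kCube - 1) / kCube;
  return ((((n * s[2] + d) * c1 + c / kCube) * s[3] + h) * s[4] + w) * kCube + c % kCube;
}

int main() {
  const std::vector<size_t> ncdhw = {2, 20, 3, 4, 5};  // C = 20 -> C1 = 2 (12 zero-padded lanes)
  const auto dev = ToNdc1hwc0Shape(ncdhw);             // {2, 3, 2, 4, 5, 16}
  std::printf("C1 = %zu, device elements = %zu\n", dev[2],
              dev[0] * dev[1] * dev[2] * dev[3] * dev[4] * dev[5]);
  std::printf("offset of (n=0, c=17, d=0, h=0, w=0) = %zu\n",
              DeviceOffset(ncdhw, 0, 17, 0, 0, 0));  // block c1=1, lane c0=1 -> 321
  return 0;
}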
diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc
index ec1c1cacf6..78a0176d51 100644
--- a/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc
+++ b/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc
@@ -292,7 +292,7 @@ bool AscendDeviceAddress::SyncDeviceToHost(const ShapeVector &shape, size_t size,
   if (host_shape.empty()) {
     host_shape.emplace_back(1);
   }
-  if (format_ == kOpFormat_NCHW || format_ == kOpFormat_DEFAULT || format_ == kOpFormat_NDHWC) {
+  if (format_ == kOpFormat_NCHW || format_ == kOpFormat_DEFAULT || format_ == kOpFormat_NCDHW) {
     if (type_id_ == type) {
       SyncMemory(host_ptr, ptr_, size, RT_MEMCPY_DEVICE_TO_HOST);
       sync_ok = true;
@@ -454,7 +454,7 @@ std::vector<size_t> AscendDeviceAddress::GetWorkspaceSizeList(const nlohmann::json
 std::vector<size_t> AscendDeviceAddress::GetDeviceShape(std::vector<size_t> *host_shape) const {
   std::vector<size_t> device_shape;
-  if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NDHWC) {
+  if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NCDHW) {
     device_shape = trans::TransShapeToDevice(*host_shape, format_);
   } else {
     if (host_shape_.empty()) {
@@ -531,7 +531,7 @@ bool AscendDeviceAddress::SyncHostToDevice(const ShapeVector &shape, size_t size,
   if (host_shape.empty()) {
     host_shape.emplace_back(1);
   }
-  if (format_ == kOpFormat_NCHW || format_ == kOpFormat_DEFAULT || format_ == kOpFormat_NDHWC) {
+  if (format_ == kOpFormat_NCHW || format_ == kOpFormat_DEFAULT || format_ == kOpFormat_NCDHW) {
     if (type_id_ == type) {
       SyncMemory(ptr_, host_ptr, size, RT_MEMCPY_HOST_TO_DEVICE);
       sync_ok = true;
@@ -575,7 +575,7 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const ShapeVector &shape,
     host_shape.emplace_back(1);
   }
   std::vector<size_t> device_shape;
-  if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NDHWC) {
+  if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NCDHW) {
    device_shape = trans::TransShapeToDevice(host_shape, format_);
   } else {
     host_shape = trans::PaddingShapeTo4d(host_shape);
diff --git a/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc b/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc
index d0e9262964..18cf9177d1 100644
--- a/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc
+++ b/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc
@@ -81,6 +81,7 @@ string GetPriorityMatchFormat(const CNodePtr &cnode) {
   string priority_matched_format = kOpFormat_NC1HWC0;
   bool is_init = false;
   bool need_change_nd = false;
+  bool is_5d_input = false;
   for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(cnode); ++index) {
     auto pre_output_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, index);
     if (AnfAlgo::IsFeatureMapInput(cnode, index) &&
@@ -93,14 +94,21 @@ string GetPriorityMatchFormat(const CNodePtr &cnode) {
       priority_matched_format = kOpFormat_DEFAULT;
     }
     auto input_shape_size = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index).size();
+    if (input_shape_size == 5) {
+      is_5d_input = true;
+    }
     need_change_nd = (need_change_nd || (input_shape_size != 4 && input_shape_size > 1));
   }
   if (need_change_nd && priority_matched_format != kOpFormat_FRAC_NZ) {
     priority_matched_format = kOpFormat_DEFAULT;
   }
+  if (is_5d_input && priority_matched_format != kOpFormat_FRAC_NZ) {
+    priority_matched_format = kOpFormat_NDC1HWC0;
+  }
   AnfAlgo::SetNodeAttr(kPriChoosenFormat, MakeValue(priority_matched_format), cnode);
   return priority_matched_format;
 }
+
 /**
  * Compare two vector by priority, select a better vector, like compare two num, first compare highest num location,
  * if equal then next num location
@@ -157,7 +165,8 @@ void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, const CNodePtr &kernel_node,
     if (kernel_build_info.GetInputFormat(input_index) == pri_match_format) {
       (*cur_kernelinfo_match_counts)[MATCH_SPECIAL_FORMAT_COUNT] += base_score;
     }
-    if (kernel_build_info.GetInputFormat(input_index) == kOpFormat_DEFAULT) {
+    if (kernel_build_info.GetInputFormat(input_index) == kOpFormat_DEFAULT ||
+        kernel_build_info.GetInputFormat(input_index) == kOpFormat_NCDHW) {
      (*cur_kernelinfo_match_counts)[MATCH_DEFAULT_FORMAT_COUNT] += base_score;
    }
  }
@@ -376,7 +385,9 @@ void SetTensorDeviceInfo(const CNodePtr &kernel_node) {
     std::vector<std::string> output_format = {AnfAlgo::GetOutputFormat(real_input_node, 0)};
     if (IsValueNode<tensor::Tensor>(input_kernel_node) &&
         AnfAlgo::GetOutputDeviceDataType(input_kernel_node, 0) == kTypeUnknown) {
-      if (selected_kernel_info->GetInputFormat(input_index) != kOpFormat_FRACTAL_ZN_LSTM) {
+      if (selected_kernel_info->GetInputFormat(input_index) != kOpFormat_FRACTAL_ZN_LSTM &&
+          selected_kernel_info->GetInputFormat(input_index) != kOpFormat_FRACTAL_Z_3D &&
+          selected_kernel_info->GetInputFormat(input_index) != kOpFormat_NDC1HWC0) {
         output_format = {selected_kernel_info->GetInputFormat(input_index)};
       }
       builder->SetOutputsFormat(output_format);
@@ -386,7 +397,9 @@ void SetTensorDeviceInfo(const CNodePtr &kernel_node) {
       continue;
     }
     if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) {
-      if (selected_kernel_info->GetInputFormat(input_index) != kOpFormat_FRACTAL_ZN_LSTM) {
+      if (selected_kernel_info->GetInputFormat(input_index) != kOpFormat_FRACTAL_ZN_LSTM &&
+          selected_kernel_info->GetInputFormat(input_index) != kOpFormat_FRACTAL_Z_3D &&
+          selected_kernel_info->GetInputFormat(input_index) != kOpFormat_NDC1HWC0) {
         output_format = {selected_kernel_info->GetInputFormat(input_index)};
       }
       builder->SetOutputsFormat(output_format);
diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h
index 89bdc00065..28992f123c 100644
--- a/mindspore/ccsrc/utils/utils.h
+++ b/mindspore/ccsrc/utils/utils.h
@@ -386,11 +386,23 @@ constexpr auto kOpFormat_C1HWNCoC0 = "C1HWNCoC0";
 constexpr auto kOpFormat_NC1HWC0_C04 = "NC1HWC0_C04";
 constexpr auto kOpFormat_FRACTAL_Z_C04 = "FRACTAL_Z_C04";
 constexpr auto kOpFormat_NDHWC = "NDHWC";
+constexpr auto kOpFormat_NCDHW = "NCDHW";
+constexpr auto kOpFormat_DHWNC = "DHWNC";
+constexpr auto kOpFormat_DHWCN = "DHWCN";
+constexpr auto kOpFormat_NDC1HWC0 = "NDC1HWC0";
+constexpr auto kOpFormat_FRACTAL_Z_3D = "FRACTAL_Z_3D";
 constexpr auto kOpFormat_FRACTAL_ZN_LSTM = "FRACTAL_ZN_LSTM";
-const std::set<std::string> kOpFormatList = {
-  kOpFormat_DEFAULT,     kOpFormat_NC1KHKWHWC0,   kOpFormat_ND,     kOpFormat_NCHW,      kOpFormat_NHWC,
-  kOpFormat_HWCN,        kOpFormat_NC1HWC0,       kOpFormat_FRAC_Z, kOpFormat_C1HWNCoC0, kOpFormat_FRAC_NZ,
-  kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04, kOpFormat_NDHWC,  kOpFormat_FRACTAL_ZN_LSTM};
+
+const std::set<std::string> kOpFormatList = {kOpFormat_DEFAULT,      kOpFormat_NC1KHKWHWC0,
+                                             kOpFormat_ND,           kOpFormat_NCHW,
+                                             kOpFormat_NHWC,         kOpFormat_HWCN,
+                                             kOpFormat_NC1HWC0,      kOpFormat_FRAC_Z,
+                                             kOpFormat_C1HWNCoC0,    kOpFormat_FRAC_NZ,
+                                             kOpFormat_NC1HWC0_C04,  kOpFormat_FRACTAL_Z_C04,
+                                             kOpFormat_NDHWC,        kOpFormat_FRACTAL_ZN_LSTM,
+                                             kOpFormat_NDC1HWC0,     kOpFormat_NCDHW,
+                                             kOpFormat_FRACTAL_Z_3D, kOpFormat_DHWNC,
+                                             kOpFormat_DHWCN};
 const std::set<std::string> kDefaultCompatibleFormat = {kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_HWCN};
 
 const std::set<std::string> kOptOperatorSet = {kMomentumOpName,
                                                kApplyMomentumOpName,
@@ -427,8 +439,8 @@ const std::set<std::string> kOptOperatorSet = {kMomentumOpName,
                                                kSparseApplyProximalAdagradOpName};
 
 const std::set<std::string> kHWSpecialFormatSet = {
-  kOpFormat_FRAC_Z,    kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0,       kOpFormat_FRAC_NZ,
-  kOpFormat_C1HWNCoC0, kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04, kOpFormat_FRACTAL_ZN_LSTM};
+  kOpFormat_FRACTAL_Z_3D, kOpFormat_NC1KHKWHWC0,   kOpFormat_NC1HWC0,         kOpFormat_FRAC_NZ,  kOpFormat_C1HWNCoC0,
+  kOpFormat_NC1HWC0_C04,  kOpFormat_FRACTAL_Z_C04, kOpFormat_FRACTAL_ZN_LSTM, kOpFormat_NDC1HWC0, kOpFormat_FRAC_Z};
 
 const std::set<TypeId> kFloatDataTypeSet = {kNumberTypeFloat16, kNumberTypeFloat32};
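With the NCDHW and NDC1HWC0 constants registered in kOpFormatList and kHWSpecialFormatSet, the selection logic in GetPriorityMatchFormat above can prefer the 5D block format. A compact restatement of that final decision, with input tensors summarized by rank (a hypothetical helper for illustration, not the MindSpore API):

#include <string>
#include <vector>

// Simplified mirror of GetPriorityMatchFormat's endgame: start from 5HD,
// drop to the default format when a non-4D rank appears, then promote 5D
// inputs to NDC1HWC0 unless FRAC_NZ has already been chosen.
std::string PickPriorityFormat(const std::vector<size_t> &input_ranks, bool prefer_frac_nz) {
  std::string format = prefer_frac_nz ? "FRACTAL_NZ" : "NC1HWC0";
  bool need_change_nd = false;
  bool is_5d_input = false;
  for (size_t rank : input_ranks) {
    if (rank == 5) is_5d_input = true;
    need_change_nd = need_change_nd || (rank != 4 && rank > 1);
  }
  if (need_change_nd && format != "FRACTAL_NZ") format = "DefaultFormat";  // kOpFormat_DEFAULT
  if (is_5d_input && format != "FRACTAL_NZ") format = "NDC1HWC0";
  return format;
}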