diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc index 9c227abf74..89892baa10 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc @@ -32,6 +32,7 @@ namespace mindspore { namespace opt { using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder; namespace { +const std::set k3DFormatSet = {kOpFormat_NCDHW, kOpFormat_NDC1HWC0, kOpFormat_FRACTAL_Z_3D}; const std::set kCommonFormatSet = {kOpFormat_DEFAULT, kOpFormat_ND, kOpFormat_NCHW, kOpFormat_NCDHW}; AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, const KernelSelectPtr &kernel_select, const std::vector &dst_shape) { @@ -64,20 +65,30 @@ void SetTransNodeAttr(const CNodePtr &trans_node) { } } -AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const KernelSelectPtr &kernel_select, size_t insert_index, bool is_insert_input) { - AnfNodePtr trans_node = nullptr; - CNodePtr trans_data = nullptr; +std::string InitDefaultFormat(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); - // Init std::string default_format = kOpFormat_DEFAULT; - if (node->isa() && AnfAlgo::HasNodeAttr("io_format", node->cast())) { auto attr = AnfAlgo::GetNodeAttr(node, "io_format"); if (attr == kOpFormat_NCDHW) { default_format = kOpFormat_NCDHW; } + } else if (node->isa() || node->isa()) { + auto out_format = AnfAlgo::GetOutputFormat(node, 0); + if (k3DFormatSet.find(out_format) != k3DFormatSet.end()) { + default_format = kOpFormat_NCDHW; + } } + return default_format; +} + +AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const KernelSelectPtr &kernel_select, size_t insert_index, bool is_insert_input) { + AnfNodePtr trans_node = nullptr; + CNodePtr trans_data = nullptr; + MS_EXCEPTION_IF_NULL(node); + // Init + std::string default_format = InitDefaultFormat(node); AnfNodePtr input_node = is_insert_input ? AnfAlgo::GetInputNode(node->cast(), insert_index) : node; std::string input_format = is_insert_input ? default_format : AnfAlgo::GetOutputFormat(node, insert_index); std::string dst_format = is_insert_input ? AnfAlgo::GetInputFormat(node, insert_index) : default_format; diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc index 78a0176d51..596a9ebb47 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc @@ -575,7 +575,7 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const ShapeVector &sh host_shape.emplace_back(1); } std::vector device_shape; - if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NCDHW) { + if (format_ == kOpFormat_FRAC_NZ || format_ == kOpFormat_NCDHW || kOpFormat_NDC1HWC0) { device_shape = trans::TransShapeToDevice(host_shape, format_); } else { host_shape = trans::PaddingShapeTo4d(host_shape);