From 5d225f934f43662b7244ce9cfa033215a0b6db00 Mon Sep 17 00:00:00 2001
From: lianliguang
Date: Wed, 8 Apr 2020 17:42:56 +0800
Subject: [PATCH] change the padding strategy & refactor insert transdata

---
 mindspore/ccsrc/common/trans.cc               | 132 ++++++++----
 mindspore/ccsrc/common/trans.h                |   6 +-
 .../device/ascend/ascend_device_address.cc    |   4 +-
 .../device/ascend/ascend_kernel_runtime.cc    |   5 +-
 .../device/ascend/kernel_select_ascend.cc     |  13 +-
 mindspore/ccsrc/device/kernel_adjust.cc       |   4 +-
 mindspore/ccsrc/device/kernel_info.h          |   4 +
 mindspore/ccsrc/device/kernel_runtime.cc      |  20 +-
 mindspore/ccsrc/kernel/kernel_build_info.cc   |  30 ++-
 mindspore/ccsrc/kernel/kernel_build_info.h    |   8 +-
 .../pre_activate/ascend/ascend_helper.cc      | 189 ++++++-----------
 .../ccsrc/pre_activate/ascend/ascend_helper.h |   2 +-
 .../format_type/deal_ref_trans_and_cast.cc    |   6 +-
 .../ascend/ir_fusion/transdata_split.cc       | 198 +++++++++---------
 .../ccsrc/session/anf_runtime_algorithm.cc    | 104 +++++----
 .../ccsrc/session/anf_runtime_algorithm.h     |   5 +
 mindspore/ccsrc/session/ascend_session.cc     |   5 +-
 mindspore/ccsrc/session/kernel_graph.cc       |  18 +-
 mindspore/ccsrc/session/session_basic.cc      |  13 +-
 mindspore/ccsrc/utils/utils.h                 |   4 +-
 ...er_norm_beta_gamma_backprop_fusion_test.cc |   2 +
 .../cpp/session/anf_runtime_algorithm_test.cc |  26 +--
 tests/ut/cpp/session/kernel_graph_test.cc     |   6 +-
 23 files changed, 425 insertions(+), 379 deletions(-)

diff --git a/mindspore/ccsrc/common/trans.cc b/mindspore/ccsrc/common/trans.cc
index 380c51bcf9..a2b9f7ef24 100644
--- a/mindspore/ccsrc/common/trans.cc
+++ b/mindspore/ccsrc/common/trans.cc
@@ -20,6 +20,8 @@
 #include 
 #include "./securec.h"
 #include "common/utils.h"
+#include "session/anf_runtime_algorithm.h"
+#include "kernel/kernel.h"
 #include "device/convert_tensor_utils.h"
 #include "utils/convert_utils.h"
 #include "utils/log_adapter.h"
@@ -27,6 +29,33 @@
 namespace mindspore {
 namespace trans {
+namespace {
+std::vector PaddingShapeTo4dByDefault(const std::vector &shape) {
+  std::vector shape_4d(4, 1);
+  switch (shape.size()) {
+    case 0:
+      return shape_4d;
+    case 1:
+      shape_4d[1] = shape[0];
+      break;
+    case 2:
+      shape_4d[1] = shape[0];
+      shape_4d[2] = shape[1];
+      break;
+    case 3:
+      shape_4d[1] = shape[0];
+      shape_4d[2] = shape[1];
+      shape_4d[3] = shape[2];
+      break;
+    case 4:
+      std::copy(shape.begin(), shape.end(), shape_4d.begin());
+      break;
+    default:
+      MS_LOG(EXCEPTION) << "Unexpected shape size = " << shape.size();
+  }
+  return shape_4d;
+}
+}  // namespace
 const size_t kNchwDims = 4;
 const std::map type_map = {{kNumberTypeBool, 1},  {kNumberTypeInt, 4},   {kNumberTypeInt8, 1},
                            {kNumberTypeInt16, 2}, {kNumberTypeInt32, 4}, {kNumberTypeInt64, 8},
@@ -154,38 +183,64 @@ size_t TypeIdSize(const TypeId data_type) {
   return unsupported_type_error;
 }
 
-std::vector TransShapeTo4d(const std::vector &shape) {
+bool IsNeedPadding(const std::string &format, const size_t shape_size) {
+  if (shape_size == 0) {
+    return false;
+  }
+  if (format == kOpFormat_DEFAULT || format == kOpFormat_FRAC_NZ) {
+    return false;
+  } else if (shape_size < 4) {
+    return true;
+  }
+  return false;
+}
+
+std::vector GetRuntimePaddingShape(const AnfNodePtr &node, size_t index) {
+  std::vector shape;
+  std::vector host_shape;
+  if (node->isa()) {
+    auto value_node = node->cast();
+    auto node_value = value_node->value();
+    auto tensor = node_value->cast();
+    if (tensor == nullptr) {
+      MS_LOG(EXCEPTION) << "The node [" << node->DebugString() << "]'s value cannot be converted to tensor";
+    }
+    shape = tensor->shape();
+    (void)std::transform(shape.begin(), shape.end(), std::back_inserter(host_shape), IntToSize);
+    if (host_shape.empty()) {
+      host_shape.push_back(1);
+    }
+  } else {
+    host_shape = AnfAlgo::GetOutputInferShape(node, index);
+  }
+  if (trans::IsNeedPadding(AnfAlgo::GetOutputFormat(node, 0), host_shape.size())) {
+    host_shape = trans::PaddingShapeTo4d(host_shape, AnfAlgo::GetOutputReshapeType(node, 0));
+  }
+  std::transform(host_shape.begin(), host_shape.end(), std::back_inserter(shape), SizeToInt);
+  return shape;
+}
+
+std::vector PaddingShapeTo4d(const std::vector &shape, const std::vector &padding_axis) {
+  if (padding_axis.empty() || shape.size() != padding_axis.size()) {
+    return PaddingShapeTo4dByDefault(shape);
+  }
   std::vector shape_4d(4, 1);
-  switch (shape.size()) {
-    case 0:
-      break;
-    case 1:
-      shape_4d[1] = shape[0];
-      break;
-    case 2:
-      shape_4d[0] = shape[0];
-      shape_4d[1] = shape[1];
-      break;
-    case 3:
-      MS_LOG(EXCEPTION) << "Unexpected shape size = 3,it should has a default format";
-    case 4:
-      for (size_t i = 0; i < 4; ++i) {
-        shape_4d[i] = shape[i];
-      }
-      break;
-    default:
-      MS_LOG(EXCEPTION) << "Unexpected shape size = " << shape.size();
+  for (size_t index = 0; index < padding_axis.size(); index++) {
+    shape_4d[padding_axis[index]] = shape[index];
   }
   return shape_4d;
 }
 
 std::vector TransShapeToDevice(const std::vector &shape, const std::string &format) {
+  if (format == kOpFormat_ND || format == kOpFormat_DEFAULT) {
+    return shape;
+  }
+  auto temp_shape = shape;
   std::vector device_shape;
   if (format == kOpFormat_FRAC_NZ) {
     if (shape.size() < 2) {
-      MS_EXCEPTION(NotSupportError) << "Format " << format << " is not support shape " << shape.size();
-    }
-    if (shape.size() > 2) {
+      MS_LOG(EXCEPTION) << "Format " << format << " does not support shape size " << shape.size();
+    } else {
       (void)std::copy(shape.begin(), shape.end() - 2, std::back_inserter(device_shape));
     }
     auto h1 = (shape[shape.size() - 2] - 1) / kCubeSize + 1;
@@ -197,35 +252,36 @@ std::vector TransShapeToDevice(const std::vector &shape, const s
     return device_shape;
   }
   if (shape.size() != 4) {
-    MS_LOG(EXCEPTION) << "shape_4d size should be 4";
+    MS_LOG(WARNING) << "Got a device shape request with shape size less than 4; the shape is padded to 4-D by default first";
+    temp_shape = PaddingShapeTo4dByDefault(shape);
   }
   if (format == kOpFormat_NC1HWC0) {
-    size_t C1 = (shape[1] + kCubeSize - 1) / kCubeSize;
+    size_t C1 = (temp_shape[1] + kCubeSize - 1) / kCubeSize;
     size_t C0 = kCubeSize;
-    device_shape.push_back(shape[0]);
+    device_shape.push_back(temp_shape[0]);
     device_shape.push_back(C1);
-    device_shape.push_back(shape[2]);
-    device_shape.push_back(shape[3]);
+    device_shape.push_back(temp_shape[2]);
+    device_shape.push_back(temp_shape[3]);
     device_shape.push_back(C0);
     return device_shape;
   } else if (format == kOpFormat_FRAC_Z) {
-    size_t cout16 = ((shape[0] + kCubeSize - 1) / kCubeSize) * kCubeSize;
-    size_t cin16 = ((shape[1] + kCubeSize - 1) / kCubeSize) * kCubeSize;
-    device_shape.push_back(shape[2] * shape[3] * cin16 / kCubeSize);
+    size_t cout16 = ((temp_shape[0] + kCubeSize - 1) / kCubeSize) * kCubeSize;
+    size_t cin16 = ((temp_shape[1] + kCubeSize - 1) / kCubeSize) * kCubeSize;
+    device_shape.push_back(temp_shape[2] * temp_shape[3] * cin16 / kCubeSize);
     device_shape.push_back(cout16 / kCubeSize);
     device_shape.push_back(kCubeSize);
     device_shape.push_back(kCubeSize);
     return device_shape;
   } else if (format == kOpFormat_NHWC) {
-    device_shape.push_back(shape[0]);
-    device_shape.push_back(shape[2]);
-    device_shape.push_back(shape[3]);
-    device_shape.push_back(shape[1]);
+ device_shape.push_back(temp_shape[0]); + device_shape.push_back(temp_shape[2]); + device_shape.push_back(temp_shape[3]); + device_shape.push_back(temp_shape[1]); return device_shape; - } else if (format == kOpFormat_NCHW) { - return shape; } else if (format == kOpFormat_HWCN) { - return {shape[2], shape[3], shape[1], shape[0]}; + return {temp_shape[2], temp_shape[3], temp_shape[1], temp_shape[0]}; + } else if (format == kOpFormat_NCHW) { + return temp_shape; } MS_LOG(EXCEPTION) << "Unexpected format[" << format << "]"; } diff --git a/mindspore/ccsrc/common/trans.h b/mindspore/ccsrc/common/trans.h index cf815985ff..4bebdde814 100644 --- a/mindspore/ccsrc/common/trans.h +++ b/mindspore/ccsrc/common/trans.h @@ -24,6 +24,7 @@ #include #include #include "ir/dtype.h" +#include "kernel/kernel.h" #include "ir/dtype/type.h" namespace mindspore { @@ -49,7 +50,10 @@ size_t TypeIdSize(const TypeId data_type); size_t ShapeSize(const std::vector &shape); size_t CubeSizeByType(const TypeId data_type); -std::vector TransShapeTo4d(const std::vector &shape); +std::vector PaddingShapeTo4d(const std::vector &shape, + const std::vector &padding_axis = {}); +std::vector GetRuntimePaddingShape(const AnfNodePtr &node, size_t index); +bool IsNeedPadding(const std::string &format, const size_t shape_size); std::vector TransShapeToDevice(const std::vector &shape, const std::string &format); bool TransDataType(const TypeIdArgs &args, void *result); bool TransFormat(const FormatArgs &args, void *result); diff --git a/mindspore/ccsrc/device/ascend/ascend_device_address.cc b/mindspore/ccsrc/device/ascend/ascend_device_address.cc index 93f039af0e..69d1918163 100644 --- a/mindspore/ccsrc/device/ascend/ascend_device_address.cc +++ b/mindspore/ccsrc/device/ascend/ascend_device_address.cc @@ -141,7 +141,7 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const std::vectorisa()) { + continue; + } + if (real_input_node->isa() && !AnfAlgo::IsParameterWeight(real_input_node->cast())) { + continue; + } std::shared_ptr builder = std::make_shared(); // we set special device info of a input tensor. 
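A minimal standalone sketch of how the new padding strategy in trans.cc composes: default 4-D padding (axes filled starting at C, as in PaddingShapeTo4dByDefault) followed by a device-format rule such as NC1HWC0. PadTo4dByDefault and Nc1hwc0 below are simplified stand-ins written for illustration, not the patch's own functions, and kCubeSize == 16 is assumed.

#include <cstddef>
#include <iostream>
#include <vector>

constexpr size_t kCubeSize = 16;

// Default padding: sizes 1-3 fill axes starting at C (index 1); N and any
// trailing axes stay 1. Mirrors PaddingShapeTo4dByDefault for sizes 0-4.
std::vector<size_t> PadTo4dByDefault(const std::vector<size_t> &shape) {
  if (shape.size() >= 4) {
    return shape;  // already 4-D; shapes above 4-D are omitted in this sketch
  }
  std::vector<size_t> shape_4d(4, 1);
  for (size_t i = 0; i < shape.size(); ++i) {
    shape_4d[i + 1] = shape[i];
  }
  return shape_4d;
}

// NC1HWC0: the C axis splits into C1 = ceil(C / C0) blocks of C0 = kCubeSize.
std::vector<size_t> Nc1hwc0(const std::vector<size_t> &nchw) {
  size_t c1 = (nchw[1] + kCubeSize - 1) / kCubeSize;
  return {nchw[0], c1, nchw[2], nchw[3], kCubeSize};
}

int main() {
  // A 2-D host shape [32, 20] pads to [1, 32, 20, 1] and maps to the
  // NC1HWC0 device shape [1, 2, 20, 1, 16].
  for (size_t d : Nc1hwc0(PadTo4dByDefault({32, 20}))) {
    std::cout << d << ' ';
  }
  std::cout << '\n';
}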
diff --git a/mindspore/ccsrc/device/kernel_adjust.cc b/mindspore/ccsrc/device/kernel_adjust.cc index 9a6f48025f..c1588d7d53 100644 --- a/mindspore/ccsrc/device/kernel_adjust.cc +++ b/mindspore/ccsrc/device/kernel_adjust.cc @@ -25,6 +25,7 @@ #include "session/anf_runtime_algorithm.h" #include "utils/context/ms_context.h" +#include "common/trans.h" #include "utils/config_manager.h" #include "common/utils.h" #include "kernel/kernel_build_info.h" @@ -391,7 +392,8 @@ bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr &c auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0); MS_EXCEPTION_IF_NULL(device_address); tensor->set_device_address(device_address); - if (!device_address->SyncHostToDevice(tensor->shape(), LongToSize(tensor->data().nbytes()), tensor->data_type(), + if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0), + LongToSize(tensor->data().nbytes()), tensor->data_type(), tensor->data_c(false))) { MS_LOG(INFO) << "SyncHostToDevice failed."; return false; diff --git a/mindspore/ccsrc/device/kernel_info.h b/mindspore/ccsrc/device/kernel_info.h index 9352158774..33ddda83c9 100644 --- a/mindspore/ccsrc/device/kernel_info.h +++ b/mindspore/ccsrc/device/kernel_info.h @@ -31,6 +31,7 @@ class KernelInfo { public: KernelInfo() { kernel_mod_ = nullptr; + is_feature_map_ = false; select_kernel_build_info_ = nullptr; output_address_list_ = {}; workspace_address_list_ = {}; @@ -45,6 +46,7 @@ class KernelInfo { void set_select_kernel_build_info(const kernel::KernelBuildInfoPtr &select_kernel_build_info) { select_kernel_build_info_ = select_kernel_build_info; } + void SetFeatureMapFlag(bool flag) { is_feature_map_ = flag; } const DeviceAddress *GetOutputAddr(size_t index) const; DeviceAddressPtr GetMutableOutputAddr(size_t index) const; bool OutputAddrExist(size_t index) const; @@ -63,8 +65,10 @@ class KernelInfo { void set_graph_id(uint32_t graph_id) { graph_id_ = graph_id; } uint32_t graph_id() const { return graph_id_; } bool operator==(const KernelInfo &other) const; + bool is_feature_map() const { return is_feature_map_; } private: + bool is_feature_map_; kernel::KernelBuildInfoPtr select_kernel_build_info_; std::vector> output_address_list_; std::vector> workspace_address_list_; diff --git a/mindspore/ccsrc/device/kernel_runtime.cc b/mindspore/ccsrc/device/kernel_runtime.cc index eebc650347..303f2cc873 100644 --- a/mindspore/ccsrc/device/kernel_runtime.cc +++ b/mindspore/ccsrc/device/kernel_runtime.cc @@ -105,7 +105,7 @@ size_t KernelRuntime::CountNodeDeviceMemorySize(const mindspore::AnfNodePtr &nod std::vector shape = AnfAlgo::GetOutputDeviceShape(node, output_index); auto format = AnfAlgo::GetOutputFormat(node, output_index); if (shape.empty() && format != kOpFormat_DEFAULT) { - shape = trans::TransShapeTo4d(shape); + shape = trans::PaddingShapeTo4d(shape, AnfAlgo::GetOutputReshapeType(node, output_index)); shape = trans::TransShapeToDevice(shape, format); } // scalar's output shape is a empty vector @@ -401,8 +401,9 @@ void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const auto address = CreateDeviceAddress(ptr, node_size, AnfAlgo::GetOutputFormat(value_node, output_idx), output_type_id); MS_EXCEPTION_IF_NULL(address); AnfAlgo::SetOutputAddr(address, output_idx, value_node.get()); - if (!address->SyncHostToDevice(tensor->shape(), tensor_size, tensor->data_type(), tensor->data_c(false))) { - MS_EXCEPTION(NotExistsError) << "kValueNode SyncHostToDevice fail!" 
<< value_node->DebugString() << "node format is" + if (!address->SyncHostToDevice(trans::GetRuntimePaddingShape(value_node, 0), tensor_size, tensor->data_type(), + tensor->data_c(false))) { + MS_EXCEPTION(NotExistsError) << "ValueNode SyncHostToDevice fail!" << value_node->DebugString() << "node format is" << AnfAlgo::GetOutputFormat(value_node, output_idx) << "node dtype is " << AnfAlgo::GetOutputInferDataType(value_node, output_idx); } @@ -421,19 +422,6 @@ void KernelRuntime::AssignStaticMemoryValueNode(session::KernelGraph *graph) { MS_EXCEPTION_IF_NULL(node_value); if (node_value->isa()) { AssignValueNodeTensor(value_node, node_value, 0); - } else if (node_value->isa()) { - auto value_tuple = node_value->cast(); - if (value_tuple == nullptr) { - MS_LOG(WARNING) << "value_tuple is null"; - continue; - } - size_t i = 0; - auto value_list = value_tuple->value(); - for (auto value_ptr : value_list) { - if (value_ptr->isa()) { - AssignValueNodeTensor(value_node, value_ptr, i++); - } - } } else if (node_value->isa()) { auto value = GetValue(node_value); size_t tensor_size = value.size(); diff --git a/mindspore/ccsrc/kernel/kernel_build_info.cc b/mindspore/ccsrc/kernel/kernel_build_info.cc index c52f71c136..038c06d8ed 100644 --- a/mindspore/ccsrc/kernel/kernel_build_info.cc +++ b/mindspore/ccsrc/kernel/kernel_build_info.cc @@ -59,30 +59,20 @@ size_t KernelBuildInfo::GetInputNum() const { return inputs_format_.size(); } size_t KernelBuildInfo::GetOutputNum() const { return outputs_format_.size(); } -bool KernelBuildInfo::GetInputReshapeType(size_t input_index, std::vector *reshape_type) const { - MS_EXCEPTION_IF_NULL(reshape_type); - reshape_type->clear(); +std::vector KernelBuildInfo::GetInputReshapeType(size_t input_index) const { if (input_index >= input_reshape_type_.size()) { - MS_LOG(WARNING) << "The index [" << input_index << "] is exceed the number of input node size " - << input_reshape_type_.size(); - return false; + MS_LOG(EXCEPTION) << "The index [" << input_index << "] is exceed the number of input node size " + << input_reshape_type_.size(); } - (void)std::copy(input_reshape_type_[input_index].begin(), input_reshape_type_[input_index].end(), - std::inserter(*reshape_type, (*reshape_type).begin())); - return true; + return input_reshape_type_[input_index]; } -bool KernelBuildInfo::GetOutputReshapeType(size_t output_index, std::vector *reshape_type) const { - MS_EXCEPTION_IF_NULL(reshape_type); - reshape_type->clear(); +std::vector KernelBuildInfo::GetOutputReshapeType(size_t output_index) const { if (output_index >= output_reshape_type_.size()) { - MS_LOG(WARNING) << "The index [" << output_index << "] is exceed the number of output node dixr" - << output_reshape_type_.size(); - return false; + MS_LOG(EXCEPTION) << "The index [" << output_index << "] is exceed the number of output node size " + << output_reshape_type_.size(); } - (void)std::copy(output_reshape_type_[output_index].begin(), output_reshape_type_[output_index].end(), - std::inserter(*reshape_type, (*reshape_type).begin())); - return true; + return output_reshape_type_[output_index]; } std::string KernelBuildInfo::ToString() const { @@ -115,6 +105,10 @@ bool KernelBuildInfo::operator==(const KernelBuildInfo &other) const { return !(inputs_device_type_ != other.inputs_device_type_ || outputs_device_type_ != other.outputs_device_type_); } +bool KernelBuildInfo::IsInputDefaultPadding() const { return output_reshape_type_.empty(); } + +bool KernelBuildInfo::IsOutputDefaultPadding() const { return 
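  // Note (reviewer annotation, not part of the patch): the bodies of these two
  // predicates appear transposed -- IsInputDefaultPadding tests
  // output_reshape_type_, while IsOutputDefaultPadding tests
  // input_reshape_type_. If that is unintentional, swapping the two members
  // keeps the GetInputReshapeType/GetOutputReshapeType bounds checks above
  // consistent with the vectors they index.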
input_reshape_type_.empty(); } + void KernelBuildInfo::KernelBuildInfoBuilder::SetKernelType(const KernelType &kernel_type) { MS_EXCEPTION_IF_NULL(kernel_build_info_); kernel_build_info_->kernel_type_ = kernel_type; diff --git a/mindspore/ccsrc/kernel/kernel_build_info.h b/mindspore/ccsrc/kernel/kernel_build_info.h index 24552e0341..76ebc7a572 100644 --- a/mindspore/ccsrc/kernel/kernel_build_info.h +++ b/mindspore/ccsrc/kernel/kernel_build_info.h @@ -54,9 +54,13 @@ class KernelBuildInfo { TypeId GetOutputDeviceType(size_t output_index) const; - bool GetInputReshapeType(size_t input_index, std::vector *reshape_type) const; + std::vector GetInputReshapeType(size_t input_index) const; - bool GetOutputReshapeType(size_t input_index, std::vector *reshape_type) const; + bool IsInputDefaultPadding() const; + + bool IsOutputDefaultPadding() const; + + std::vector GetOutputReshapeType(size_t input_index) const; std::vector GetAllInputFormats() const; diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc b/mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc index 58c030e79d..490a905a45 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc @@ -18,20 +18,21 @@ #include #include "common/trans.h" #include "common/utils.h" +#include "utils/utils.h" #include "device/kernel_info.h" #include "kernel/oplib/oplib.h" #include "operator/ops.h" #include "session/anf_runtime_algorithm.h" #include "session/kernel_graph.h" #include "utils/context/ms_context.h" -#include "utils/utils.h" namespace mindspore { namespace opt { using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder; namespace { -kernel::KernelBuildInfoPtr CreateKernelBuildInfo(const std::string &input_format, const std::string &output_format, - const AnfNodePtr &node, const kernel::KernelBuildInfo ori_build_info) { +kernel::KernelBuildInfoPtr RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, + const AnfNodePtr &node, + const kernel::KernelBuildInfo ori_build_info) { KernelBuildInfoBuilder builder; builder.SetInputsFormat({input_format}); builder.SetOutputsFormat({output_format}); @@ -54,9 +55,11 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, CNodePtr trans_node = func_graph->NewCNode(trans_inputs); MS_EXCEPTION_IF_NULL(trans_node); if (need_padding) { - AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)}, - {trans::TransShapeTo4d(AnfAlgo::GetOutputInferShape(input, 0))}, - trans_node.get()); + // if need padding we should set the transdata node's shape to the padding shape + AnfAlgo::SetOutputInferTypeAndShape( + {AnfAlgo::GetOutputInferDataType(input, 0)}, + {trans::PaddingShapeTo4d(AnfAlgo::GetOutputInferShape(input, 0), AnfAlgo::GetOutputReshapeType(input, 0))}, + trans_node.get()); } else { AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)}, {AnfAlgo::GetOutputInferShape(input, 0)}, trans_node.get()); @@ -92,9 +95,11 @@ AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &i AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &node, size_t index, const KernelSelectPtr &kernel_select) { MS_EXCEPTION_IF_NULL(node); - bool padding_flag = false; auto input_node = AnfAlgo::GetInputNode(node, index); - if (input_node->isa() || input_node->isa()) { + auto node_with_index = AnfAlgo::VisitKernel(input_node, 0); + MS_EXCEPTION_IF_NULL(node_with_index.first); + auto 
real_input = node_with_index.first; + if (real_input->isa() || real_input->isa()) { input_node = InsertTransOpForOutput(func_graph, input_node, kernel_select); MS_EXCEPTION_IF_NULL(input_node); AnfAlgo::SetNodeInput(node, input_node, index); @@ -106,33 +111,11 @@ AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr & std::vector origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, index); std::string origin_format = kOpFormat_DEFAULT; std::string dest_format = AnfAlgo::GetInputFormat(node, index); - if (dest_format == kOpFormat_C1HWNCoC0) { - padding_flag = (origin_shape.size() != kShape4dDims); - AnfNodePtr replace_input = AddTransOpNodeToGraph(func_graph, node, kernel_select, index, padding_flag, - origin_format, dest_format, kTransDataOpName, true); - MS_EXCEPTION_IF_NULL(replace_input); - return replace_input; - } - if (dest_format == kOpFormat_NC1HWC0 && origin_shape.size() > 1) { - padding_flag = (origin_shape.size() != kShape4dDims); - AnfNodePtr replace_input = AddTransOpNodeToGraph(func_graph, node, kernel_select, index, padding_flag, - origin_format, dest_format, kTransDataOpName, true); - MS_EXCEPTION_IF_NULL(replace_input); - MS_LOG(DEBUG) << "Inserted Translate45, index: " << index; - return replace_input; - } else if (dest_format == kOpFormat_FRAC_NZ) { - AnfNodePtr replace_input = AddTransOpNodeToGraph(func_graph, node, kernel_select, index, padding_flag, - origin_format, dest_format, kTransDataOpName, true); - MS_EXCEPTION_IF_NULL(replace_input); - MS_LOG(DEBUG) << "inserted translate " << AnfAlgo::GetInputFormat(node, index) << " To default, index: " << index; - return replace_input; - } else if (dest_format == kOpFormat_FRAC_Z && !origin_shape.empty()) { - padding_flag = (origin_shape.size() != kShape4dDims); - AnfNodePtr replace_input = AddTransOpNodeToGraph(func_graph, node, kernel_select, index, padding_flag, - origin_format, dest_format, kTransDataOpName, true); - MS_EXCEPTION_IF_NULL(replace_input); - MS_LOG(DEBUG) << "Inserted Translate45, index: " << index; - return replace_input; + if (kNeedTransFormatSet.find(dest_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) { + MS_LOG(DEBUG) << node->DebugString() << "Insert transdata " << AnfAlgo::GetInputFormat(node, index) + << " To DefaultFormat , index: " << index; + return AddTransOpNodeToGraph(func_graph, node, kernel_select, index, origin_format, dest_format, kTransDataOpName, + true); } return input_node; } @@ -140,7 +123,6 @@ AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr & AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const KernelSelectPtr &kernel_select) { MS_EXCEPTION_IF_NULL(node); - bool padding_flag = false; std::string output_format; std::vector origin_shape; if (!AnfAlgo::IsRealKernel(node)) { @@ -156,46 +138,14 @@ AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const An } std::string origin_format = output_format; std::string dest_format = kOpFormat_DEFAULT; - if (output_format == kOpFormat_C1HWNCoC0) { - padding_flag = (origin_shape.size() != kShape4dDims); - AnfNodePtr replace_input = AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, padding_flag, origin_format, - dest_format, kTransDataOpName, false); - MS_EXCEPTION_IF_NULL(replace_input); - return replace_input; - } - if (output_format == kOpFormat_NC1HWC0 && origin_shape.size() > 1) { - padding_flag = (origin_shape.size() != kShape4dDims); - AnfNodePtr replace_output = 
AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, padding_flag, origin_format, - dest_format, kTransDataOpName, false); - MS_EXCEPTION_IF_NULL(replace_output); - MS_LOG(DEBUG) << "Inserted Trans54"; - return replace_output; - } else if (output_format == kOpFormat_FRAC_NZ) { - AnfNodePtr replace_output = AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, padding_flag, origin_format, - dest_format, kTransDataOpName, false); - MS_EXCEPTION_IF_NULL(replace_output); - MS_LOG(DEBUG) << "Inserted Translate " << output_format << " To default, index: 0"; - return replace_output; - } else if (output_format == kOpFormat_FRAC_Z && !origin_shape.empty()) { - padding_flag = (origin_shape.size() != kShape4dDims); - AnfNodePtr replace_output = AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, padding_flag, origin_format, - dest_format, kTransDataOpName, false); - MS_EXCEPTION_IF_NULL(replace_output); - MS_LOG(DEBUG) << "Inserted Trans54"; - return replace_output; + if (kNeedTransFormatSet.find(output_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) { + MS_LOG(DEBUG) << "Inserted Transdata " << output_format << " To default , index :0"; + return AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, origin_format, dest_format, kTransDataOpName, + false); } return node; } -void GetTransDataInputFormat(const AnfNodePtr &node, size_t idx, std::string *input_format) { - MS_EXCEPTION_IF_NULL(input_format); - if (AnfAlgo::IsRealKernel(node)) { - *input_format = AnfAlgo::GetOutputFormat(node, idx); - } else { - *input_format = AnfAlgo::GetPrevNodeOutputFormat(node, 0); - } -} - AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const KernelSelectPtr &kernel_select) { MS_EXCEPTION_IF_NULL(func_graph); @@ -203,46 +153,17 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const std::vector make_tuple_inputs; make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); for (size_t output_idx = 0; output_idx < AnfAlgo::GetOutputTensorNum(node); ++output_idx) { - bool padding_flag = false; - - std::string output_format; - GetTransDataInputFormat(node, output_idx, &output_format); + std::string output_format = AnfAlgo::GetOutputFormat(node, output_idx); if (output_format == kOpFormat_NC1KHKWHWC0) { - MS_LOG(EXCEPTION) << "got the hw format" << output_format << " when insert the transdata node " + MS_LOG(EXCEPTION) << "Got the special format" << output_format << " when insert the transdata node " << node->DebugString(); } auto tuple_getitem = CreatTupleGetItemNode(func_graph, node, output_idx); std::vector origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx); - std::string origin_format = output_format; std::string dest_format = kOpFormat_DEFAULT; - if (output_format == kOpFormat_C1HWNCoC0) { - padding_flag = (origin_shape.size() != kShape4dDims); - AnfNodePtr replace_input = AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, padding_flag, - origin_format, dest_format, kTransDataOpName, false); - MS_EXCEPTION_IF_NULL(replace_input); - return replace_input; - } - if (output_format == kOpFormat_NC1HWC0 && origin_shape.size() > 1) { - padding_flag = (origin_shape.size() != kShape4dDims); - // Insert a 5to4 trans op. 
- AnfNodePtr replace_output = AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, padding_flag, - origin_format, dest_format, kTransDataOpName, false); - MS_EXCEPTION_IF_NULL(replace_output); - MS_LOG(DEBUG) << "Inserted Translate54"; - make_tuple_inputs.push_back(replace_output); - } else if (output_format == kOpFormat_FRAC_NZ) { - AnfNodePtr replace_output = AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, padding_flag, - origin_format, dest_format, kTransDataOpName, false); - MS_EXCEPTION_IF_NULL(replace_output); - MS_LOG(DEBUG) << "Inserted Translate " << output_format << " To default, index: " << output_idx; - make_tuple_inputs.push_back(replace_output); - } else if (output_format == kOpFormat_FRAC_Z && !origin_shape.empty()) { - padding_flag = (origin_shape.size() != kShape4dDims); - AnfNodePtr replace_output = AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, padding_flag, - origin_format, dest_format, kTransDataOpName, false); - MS_EXCEPTION_IF_NULL(replace_output); - MS_LOG(DEBUG) << "Inserted Translate54"; - make_tuple_inputs.push_back(replace_output); + if (kNeedTransFormatSet.find(output_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) { + make_tuple_inputs.emplace_back(AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, output_format, + dest_format, kTransDataOpName, false)); } else { // No need insert trans op. make_tuple_inputs.push_back(tuple_getitem); @@ -253,16 +174,17 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const } } // namespace AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const KernelSelectPtr &kernel_select, size_t insert_index, const bool padding_flag, + const KernelSelectPtr &kernel_select, size_t insert_index, const std::string &origin_format, const std::string &dest_format, const std::string &op_name, bool is_insert_input) { AnfNodePtr trans_node = nullptr; - AnfNodePtr input_node = nullptr; + AnfNodePtr input_node = node; AnfNodePtr trans_data = nullptr; MS_EXCEPTION_IF_NULL(node); if (origin_format.empty() || dest_format.empty()) { MS_LOG(EXCEPTION) << "trans op format is error, origin = " << origin_format << ", dest " << origin_format; } + // if insert transdata for input we need to change the input if (is_insert_input) { if (!node->isa()) { MS_LOG(EXCEPTION) << "cannot insert a transdata node to a node's input which the node is not a cnode"; @@ -270,29 +192,34 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); input_node = AnfAlgo::GetInputNode(cnode, insert_index); - if (padding_flag) { - auto padd_shape = trans::TransShapeTo4d(AnfAlgo::GetOutputInferShape(input_node, 0)); - auto reshape_node = CreateReshapeNode(func_graph, input_node, kernel_select, padd_shape); - trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, padding_flag, op_name); - } else { - trans_data = NewTransOpNode(func_graph, input_node, kernel_select, padding_flag, op_name); - } + } + bool need_padding = (trans::IsNeedPadding(dest_format, AnfAlgo::GetOutputInferShape(input_node, 0).size()) && + op_name == kTransDataOpName); + if (!need_padding) { + // don't need padding insert transdata only + trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, op_name); + trans_node = trans_data; + } else if (is_insert_input) { + // if need padding & is input need insert a transdata + // reshape[padding shape] -> 
transdata[padding shape] -> node + auto padding_shape = + trans::PaddingShapeTo4d(AnfAlgo::GetOutputInferShape(input_node, 0), AnfAlgo::GetInputReshapeType(node, 0)); + auto reshape_node = CreateReshapeNode(func_graph, input_node, kernel_select, padding_shape); + trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, need_padding, op_name); trans_node = trans_data; } else { - input_node = node; - trans_data = NewTransOpNode(func_graph, input_node, kernel_select, padding_flag, op_name); - if (padding_flag) { - auto reshape_node = - CreateReshapeNode(func_graph, trans_data, kernel_select, AnfAlgo::GetOutputInferShape(input_node, 0)); - trans_node = reshape_node; - } else { - trans_node = trans_data; - } + // if need padding & is output need insert a transdata + // node -> transdata[padding shape] -> reshape[ori_shape] + trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, op_name); + auto reshape_node = + CreateReshapeNode(func_graph, trans_data, kernel_select, AnfAlgo::GetOutputInferShape(input_node, 0)); + trans_node = reshape_node; } + // refresh the transdata's format to ori format & dst format MS_EXCEPTION_IF_NULL(trans_data); MS_EXCEPTION_IF_NULL(trans_data->kernel_info()); auto trans_ori_build_info = trans_data->kernel_info()->select_kernel_build_info(); - auto kernel_build_info = CreateKernelBuildInfo(origin_format, dest_format, input_node, *trans_ori_build_info); + auto kernel_build_info = RefreshKernelBuildInfo(origin_format, dest_format, input_node, *trans_ori_build_info); AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, trans_data.get()); return trans_node; } @@ -376,7 +303,17 @@ CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnod for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) { TypeId origin_type; auto cur_input = AnfAlgo::GetInputNode(cnode, input_index); - if (!AnfAlgo::IsFeatureMapInput(cnode, input_index)) { + auto kernel_with_index = AnfAlgo::VisitKernel(cur_input, 0); + auto is_weight_boundary = [](const AnfNodePtr &node) -> bool { + if (node->isa()) { + return true; + } else if (node->isa() && AnfAlgo::IsParameterWeight(node->cast())) { + return true; + } + return false; + }; + auto real_input_node = kernel_with_index.first; + if (is_weight_boundary(real_input_node)) { // weight origin_type = AnfAlgo::GetPrevNodeOutputDeviceDataType(cnode, input_index); } else { diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_helper.h b/mindspore/ccsrc/pre_activate/ascend/ascend_helper.h index b605d700c3..8925a52a7d 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ascend_helper.h +++ b/mindspore/ccsrc/pre_activate/ascend/ascend_helper.h @@ -48,7 +48,7 @@ class KernelQuery { using KernelQueryPtr = std::shared_ptr; AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const KernelSelectPtr &kernel_select, size_t insert_index, bool padding_flag, + const KernelSelectPtr &kernel_select, size_t insert_index, const std::string &origin_format, const std::string &dest_format, const std::string &op_name, bool is_insert_input); diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.cc b/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.cc index 81e5c4b486..2d44bf8f8f 100644 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.cc +++ b/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.cc @@ -105,10 +105,8 @@ AnfNodePtr 
AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodeP // insert trans if (origin_format != cur_format) { auto kernel_select = std::make_shared(); - bool need_padding = - (cur_format == kOpFormat_NC1HWC0 && AnfAlgo::GetOutputInferShape(final_node, 0).size() != kShape4dDims); - final_node = AddTransOpNodeToGraph(func_graph, final_node, kernel_select, 0, need_padding, cur_format, - origin_format, kTransDataOpName, false); + final_node = AddTransOpNodeToGraph(func_graph, final_node, kernel_select, 0, cur_format, origin_format, + kTransDataOpName, false); final_index = 0; MS_EXCEPTION_IF_NULL(final_node); MS_LOG(INFO) << "DealRefTransAndCast add trans op, op debug info is " << final_node->DebugString(); diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transdata_split.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transdata_split.cc index faef277599..d3990fe898 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transdata_split.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/transdata_split.cc @@ -1,99 +1,99 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "pre_activate/ascend/ir_fusion/transdata_split.h" -#include -#include "pre_activate/ascend/ascend_helper.h" -#include "session/anf_runtime_algorithm.h" -#include "debug/anf_ir_dump.h" - -namespace mindspore { -namespace opt { -const std::set> invalid_formats_pair = {{kOpFormat_C1HWNCoC0, kOpFormat_NCHW}, - {kOpFormat_NCHW, kOpFormat_C1HWNCoC0}, - {kOpFormat_C1HWNCoC0, kOpFormat_DEFAULT}, - {kOpFormat_DEFAULT, kOpFormat_C1HWNCoC0}}; - -bool TransDataSplit::Run(const FuncGraphPtr &func_graph) { - MS_EXCEPTION_IF_NULL(func_graph); - bool changed = false; - std::vector node_list = TopoSort(func_graph->get_return()); - for (auto &node : node_list) { - if (node != nullptr && node->isa() && AnfAlgo::GetCNodeName(node) == kTransDataOpName) { - CheckCNodeInputSize(node->cast(), kBackendTransDataInputNum); - if (IsFormatInvaild(node)) { - changed = DoSplit(func_graph, node); - } - } - } - return changed; -} -bool TransDataSplit::IsFormatInvaild(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto input_format = AnfAlgo::GetInputFormat(node, 0); - auto output_format = AnfAlgo::GetOutputFormat(node, 0); - auto format_pair = std::make_pair(input_format, output_format); - - return invalid_formats_pair.find(format_pair) != invalid_formats_pair.end(); -} -// transdata cannot support frac_z to nchw need split transdata(frac_z-HWCN) and transpose(HWCN-NCHW) -bool TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(node); - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto input_node = node->cast()->input(1); - MS_EXCEPTION_IF_NULL(input_node); - - auto input_format = AnfAlgo::GetInputFormat(node, 0); - auto output_format = AnfAlgo::GetOutputFormat(node, 0); - AnfNodePtr 
new_transdata_node = nullptr; - AnfNodePtr new_transpose_node = nullptr; - AnfNodePtr new_replace_node = nullptr; - // if output_format=default transdata need split transdata->transpose else transpose->transdata - if (output_format == kOpFormat_DEFAULT || output_format == kOpFormat_NCHW) { - // trans input_format to hwcn - new_transdata_node = AddTransOpNodeToGraph(func_graph, node, kernel_select_, 0, false, input_format, kOpFormat_HWCN, - kTransDataOpName, true); - // trans hwcn to default_format - new_transpose_node = AddTransOpNodeToGraph(func_graph, new_transdata_node, kernel_select_, 0, false, kOpFormat_HWCN, - output_format, prim::kPrimTranspose->name(), false); - AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector{3, 2, 0, 1}), new_transpose_node); - new_replace_node = new_transpose_node; - } else { - // trans default to hwcn - new_transpose_node = AddTransOpNodeToGraph(func_graph, node, kernel_select_, 0, false, input_format, kOpFormat_HWCN, - prim::kPrimTranspose->name(), true); - AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector{2, 3, 1, 0}), new_transpose_node); - - // trans hwcn to output_format - new_transdata_node = AddTransOpNodeToGraph(func_graph, new_transpose_node, kernel_select_, 0, false, kOpFormat_HWCN, - output_format, kTransDataOpName, false); - new_replace_node = new_transdata_node; - } - FuncGraphManagerPtr manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - manager->AddFuncGraph(func_graph); - - if (!manager->Replace(node, new_replace_node)) { - MS_LOG(EXCEPTION) << "manager replace node failed"; - } - MS_LOG(INFO) << "transdata node:" << cnode->DebugString() << "split success."; - return true; -} -} // namespace opt -} // namespace mindspore +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "pre_activate/ascend/ir_fusion/transdata_split.h" +#include +#include "pre_activate/ascend/ascend_helper.h" +#include "session/anf_runtime_algorithm.h" +#include "debug/anf_ir_dump.h" + +namespace mindspore { +namespace opt { +const std::set> invalid_formats_pair = {{kOpFormat_C1HWNCoC0, kOpFormat_NCHW}, + {kOpFormat_NCHW, kOpFormat_C1HWNCoC0}, + {kOpFormat_C1HWNCoC0, kOpFormat_DEFAULT}, + {kOpFormat_DEFAULT, kOpFormat_C1HWNCoC0}}; + +bool TransDataSplit::Run(const FuncGraphPtr &func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + bool changed = false; + std::vector node_list = TopoSort(func_graph->get_return()); + for (auto &node : node_list) { + if (node != nullptr && node->isa() && AnfAlgo::GetCNodeName(node) == kTransDataOpName) { + CheckCNodeInputSize(node->cast(), kBackendTransDataInputNum); + if (IsFormatInvaild(node)) { + changed = DoSplit(func_graph, node); + } + } + } + return changed; +} +bool TransDataSplit::IsFormatInvaild(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto input_format = AnfAlgo::GetInputFormat(node, 0); + auto output_format = AnfAlgo::GetOutputFormat(node, 0); + auto format_pair = std::make_pair(input_format, output_format); + + return invalid_formats_pair.find(format_pair) != invalid_formats_pair.end(); +} +// transdata cannot support frac_z to nchw need split transdata(frac_z-HWCN) and transpose(HWCN-NCHW) +bool TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(node); + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto input_node = node->cast()->input(1); + MS_EXCEPTION_IF_NULL(input_node); + + auto input_format = AnfAlgo::GetInputFormat(node, 0); + auto output_format = AnfAlgo::GetOutputFormat(node, 0); + AnfNodePtr new_transdata_node = nullptr; + AnfNodePtr new_transpose_node = nullptr; + AnfNodePtr new_replace_node = nullptr; + // if output_format=default transdata need split transdata->transpose else transpose->transdata + if (output_format == kOpFormat_DEFAULT || output_format == kOpFormat_NCHW) { + // trans input_format to hwcn + new_transdata_node = + AddTransOpNodeToGraph(func_graph, node, kernel_select_, 0, input_format, kOpFormat_HWCN, kTransDataOpName, true); + // trans hwcn to default_format + new_transpose_node = AddTransOpNodeToGraph(func_graph, new_transdata_node, kernel_select_, 0, kOpFormat_HWCN, + output_format, prim::kPrimTranspose->name(), false); + AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector{3, 2, 0, 1}), new_transpose_node); + new_replace_node = new_transpose_node; + } else { + // trans default to hwcn + new_transpose_node = AddTransOpNodeToGraph(func_graph, node, kernel_select_, 0, input_format, kOpFormat_HWCN, + prim::kPrimTranspose->name(), true); + AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector{2, 3, 1, 0}), new_transpose_node); + + // trans hwcn to output_format + new_transdata_node = AddTransOpNodeToGraph(func_graph, new_transpose_node, kernel_select_, 0, kOpFormat_HWCN, + output_format, kTransDataOpName, false); + new_replace_node = new_transdata_node; + } + FuncGraphManagerPtr manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + manager->AddFuncGraph(func_graph); + + if (!manager->Replace(node, new_replace_node)) { + MS_LOG(EXCEPTION) << "Manager replace node failed"; + } + MS_LOG(INFO) << "Transdata node:" << cnode->DebugString() << "split success."; + return true; +} +} // namespace opt +} // namespace 
mindspore diff --git a/mindspore/ccsrc/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/session/anf_runtime_algorithm.cc index 893c379a07..3f20fec7b5 100644 --- a/mindspore/ccsrc/session/anf_runtime_algorithm.cc +++ b/mindspore/ccsrc/session/anf_runtime_algorithm.cc @@ -289,6 +289,11 @@ size_t AnfRuntimeAlgorithm::GetOutputTensorNum(const AnfNodePtr &node) { std::string AnfRuntimeAlgorithm::GetOutputFormat(const AnfNodePtr &node, size_t output_idx) { MS_EXCEPTION_IF_NULL(node); + if (output_idx > GetOutputTensorNum(node)) { + MS_LOG(EXCEPTION) << "Output index:" << output_idx + << " is out of the node output range :" << GetOutputTensorNum(node) << " #node [" + << node->DebugString() << "]"; + } auto kernel_info = node->kernel_info(); MS_EXCEPTION_IF_NULL(kernel_info); auto build_info = kernel_info->select_kernel_build_info(); @@ -298,6 +303,11 @@ std::string AnfRuntimeAlgorithm::GetOutputFormat(const AnfNodePtr &node, size_t std::string AnfRuntimeAlgorithm::GetInputFormat(const AnfNodePtr &node, size_t input_idx) { MS_EXCEPTION_IF_NULL(node); + if (input_idx > GetInputTensorNum(node)) { + MS_LOG(EXCEPTION) << "Input index :" << input_idx + << " is out of the number node Input range :" << GetInputTensorNum(node) << "#node [" + << node->DebugString() << "]"; + } auto kernel_info = node->kernel_info(); MS_EXCEPTION_IF_NULL(kernel_info); auto build_info = kernel_info->select_kernel_build_info(); @@ -362,62 +372,60 @@ std::vector AnfRuntimeAlgorithm::GetPrevNodeOutputInferShape(const AnfNo std::vector AnfRuntimeAlgorithm::GetOutputDeviceShape(const AnfNodePtr &node, size_t output_idx) { auto format = GetOutputFormat(node, output_idx); auto infer_shape = GetOutputInferShape(node, output_idx); - // if format is default_format or NC1KHKWHWC0,device shape = original shape - if (format == kOpFormat_DEFAULT || format == kOpFormat_NC1KHKWHWC0) { - return infer_shape; - } - // scalar shape if (infer_shape.empty()) { return infer_shape; } - if (format == kOpFormat_FRAC_NZ) { - return trans::TransShapeToDevice(infer_shape, format); + // if format is default_format or NC1KHKWHWC0,device shape = original shape + if (trans::IsNeedPadding(format, infer_shape.size())) { + infer_shape = trans::PaddingShapeTo4d(infer_shape, GetOutputReshapeType(node, output_idx)); } - // else trans infer shape to 4d and then calculate device shape - return trans::TransShapeToDevice(trans::TransShapeTo4d(infer_shape), format); + return trans::TransShapeToDevice(infer_shape, format); } std::vector AnfRuntimeAlgorithm::GetInputDeviceShape(const AnfNodePtr &node, size_t input_idx) { auto format = GetInputFormat(node, input_idx); auto infer_shape = GetPrevNodeOutputInferShape(node, input_idx); - // if format is default_format or NC1KHKWHWC0,device shape = original shape - if (format == kOpFormat_DEFAULT || format == kOpFormat_NC1KHKWHWC0) { - return infer_shape; - } if (infer_shape.empty()) { return infer_shape; } - if (format == kOpFormat_FRAC_NZ) { - return trans::TransShapeToDevice(infer_shape, format); + // if format is default_format or NC1KHKWHWC0,device shape = original shape + if (trans::IsNeedPadding(format, infer_shape.size())) { + infer_shape = trans::PaddingShapeTo4d(infer_shape, GetInputReshapeType(node, input_idx)); } - // else trans infer shape to 4d and then calculate device shape - return trans::TransShapeToDevice(trans::TransShapeTo4d(infer_shape), format); + return trans::TransShapeToDevice(infer_shape, format); } std::vector AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNodePtr &node, size_t input_idx) { 
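  // Note (reviewer annotation): the reshape type returned here is the
  // padding-axis list consumed by trans::PaddingShapeTo4d; an empty result
  // (the IsInputDefaultPadding case below) makes PaddingShapeTo4d fall back
  // to PaddingShapeTo4dByDefault's N,C,H,W placement.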
MS_EXCEPTION_IF_NULL(node); + if (input_idx > GetInputTensorNum(node)) { + MS_LOG(EXCEPTION) << "The index:" << input_idx + << " is out of range of the node's input size : " << GetInputTensorNum(node) << "#node[" + << node->DebugString() << "]"; + } auto kernel_info = node->kernel_info(); MS_EXCEPTION_IF_NULL(kernel_info); auto build_info = kernel_info->select_kernel_build_info(); MS_EXCEPTION_IF_NULL(build_info); - std::vector result; - if (!build_info->GetInputReshapeType(input_idx, &result)) { - MS_LOG(EXCEPTION) << "Failed to get the node's[ " << node->DebugString() << "] reshape type !"; + if (build_info->IsInputDefaultPadding()) { + return {}; } - return result; + return build_info->GetInputReshapeType(input_idx); } std::vector AnfRuntimeAlgorithm::GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx) { MS_EXCEPTION_IF_NULL(node); + if (output_idx > GetOutputTensorNum(node)) { + MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ " + << GetOutputTensorNum(node) << "#node[ " << node->DebugString() << "]"; + } auto kernel_info = node->kernel_info(); MS_EXCEPTION_IF_NULL(kernel_info); auto build_info = kernel_info->select_kernel_build_info(); MS_EXCEPTION_IF_NULL(build_info); - std::vector result; - if (!build_info->GetOutputReshapeType(output_idx, &result)) { - MS_LOG(EXCEPTION) << "Failed to get the node's[ " << node->DebugString() << "] reshape type !"; + if (build_info->IsOutputDefaultPadding()) { + return {}; } - return result; + return build_info->GetOutputReshapeType(output_idx); } TypeId AnfRuntimeAlgorithm::GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx) { @@ -463,6 +471,10 @@ TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputInferDataType(const AnfNodePtr &nod TypeId AnfRuntimeAlgorithm::GetOutputDeviceDataType(const AnfNodePtr &node, size_t output_idx) { MS_EXCEPTION_IF_NULL(node); + if (output_idx > GetOutputTensorNum(node)) { + MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ " + << GetOutputTensorNum(node) << "#node [ " << node->DebugString() << "]"; + } auto kernel_info = node->kernel_info(); MS_EXCEPTION_IF_NULL(kernel_info); auto build_info = kernel_info->select_kernel_build_info(); @@ -472,6 +484,10 @@ TypeId AnfRuntimeAlgorithm::GetOutputDeviceDataType(const AnfNodePtr &node, size TypeId AnfRuntimeAlgorithm::GetInputDeviceDataType(const AnfNodePtr &node, size_t input_idx) { MS_EXCEPTION_IF_NULL(node); + if (input_idx > GetInputTensorNum(node)) { + MS_LOG(EXCEPTION) << "The index [" << input_idx << "] is out of range of the node's input size [ " + << GetInputTensorNum(node) << "#node [ " << node->DebugString() << "]"; + } auto kernel_info = node->kernel_info(); MS_EXCEPTION_IF_NULL(kernel_info); auto build_info = kernel_info->select_kernel_build_info(); @@ -496,11 +512,15 @@ const DeviceAddress *AnfRuntimeAlgorithm::GetOutputAddr(const AnfNodePtr &node, MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node"; } } + if (output_idx > GetOutputTensorNum(node)) { + MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ " + << GetOutputTensorNum(node) << "#node:[ " << node->DebugString() << "]"; + } auto kernel_info = node->kernel_info(); MS_EXCEPTION_IF_NULL(kernel_info); auto addr = kernel_info->GetOutputAddr(output_idx); if (addr == nullptr) { - MS_LOG(EXCEPTION) << "output_idx " << output_idx << " of node " << node->DebugString() + MS_LOG(EXCEPTION) << "Output_idx " << output_idx << " of node " << 
node->DebugString() << " output addr is not exist"; } return addr; @@ -517,11 +537,15 @@ DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &nod MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node."; } } + if (output_idx > GetOutputTensorNum(node)) { + MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ " + << GetOutputTensorNum(node) << "#node:[ " << node->DebugString() << "]"; + } auto kernel_info = node->kernel_info(); MS_EXCEPTION_IF_NULL(kernel_info); auto addr = kernel_info->GetMutableOutputAddr(output_idx); if (addr == nullptr) { - MS_LOG(EXCEPTION) << "output_idx" << output_idx << " of node " << node->DebugString() + MS_LOG(EXCEPTION) << "Output_idx" << output_idx << " of node " << node->DebugString() << " output addr is not exist"; } return addr; @@ -530,6 +554,10 @@ DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &nod // get output device addr of anf_node bool AnfRuntimeAlgorithm::OutputAddrExist(const AnfNodePtr &node, size_t output_idx) { MS_EXCEPTION_IF_NULL(node); + if (output_idx > GetOutputTensorNum(node)) { + MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ " + << GetOutputTensorNum(node) << "#node:[ " << node->DebugString() << "]"; + } auto kernel_info = node->kernel_info(); MS_EXCEPTION_IF_NULL(kernel_info); return kernel_info->OutputAddrExist(output_idx); @@ -769,22 +797,24 @@ AnfNodePtr AnfRuntimeAlgorithm::GetInputNode(const CNodePtr &node, size_t index) return node->input(get_input_index); } +bool AnfRuntimeAlgorithm::IsFeatureMapOutput(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + if (node->isa()) { + return false; + } + auto kernel_info = node->kernel_info(); + MS_EXCEPTION_IF_NULL(kernel_info); + return kernel_info->is_feature_map(); +} + bool AnfRuntimeAlgorithm::IsFeatureMapInput(const AnfNodePtr &node, size_t input_index) { if (!node->isa()) { - MS_LOG(EXCEPTION) << "Cannot input a parameter or a valuenode to charge it's input if is a feature"; + MS_LOG(EXCEPTION) << "Cannot input a parameter or a valuenode to charge it's input if is a feature map"; } auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); auto input_node = cnode->input(input_index + 1); - auto node_with_index = VisitKernel(input_node, 0); - MS_EXCEPTION_IF_NULL(node_with_index.first); - if (node_with_index.first->isa()) { - return false; - } - if (node_with_index.first->isa()) { - return !AnfAlgo::IsParameterWeight(node_with_index.first->cast()); - } - return true; + return IsFeatureMapOutput(input_node); } size_t AnfRuntimeAlgorithm::GetRealInputIndex(const mindspore::AnfNodePtr &anf_node, const size_t cur_index) { diff --git a/mindspore/ccsrc/session/anf_runtime_algorithm.h b/mindspore/ccsrc/session/anf_runtime_algorithm.h index 1a1d471b84..9ac83e011f 100644 --- a/mindspore/ccsrc/session/anf_runtime_algorithm.h +++ b/mindspore/ccsrc/session/anf_runtime_algorithm.h @@ -101,7 +101,9 @@ class AnfRuntimeAlgorithm { static std::vector GetOutputDeviceShape(const AnfNodePtr &node, size_t output_idx); // get input shapes which will built and run in device static std::vector GetInputDeviceShape(const AnfNodePtr &node, size_t input_idx); + // Get Input Padding Axis static std::vector GetInputReshapeType(const AnfNodePtr &node, size_t output_idx); + // Get Output Padding Axis static std::vector GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx); // get output data type inferred by ME of anf node static TypeId 
GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx); @@ -165,6 +167,9 @@ class AnfRuntimeAlgorithm { // get graph id static uint32_t GetGraphId(const AnfNode *node); static AnfNodePtr GetInputNode(const CNodePtr &node, size_t index); + // charge if the node's output is a feature map output + static bool IsFeatureMapOutput(const AnfNodePtr &node); + // charge if the node's input is from a feature map output static bool IsFeatureMapInput(const AnfNodePtr &node, size_t input_index); // get real input index for some tbe ops which input order is different between me and tbe impl static size_t GetRealInputIndex(const AnfNodePtr &anf_node, const size_t cur_index); diff --git a/mindspore/ccsrc/session/ascend_session.cc b/mindspore/ccsrc/session/ascend_session.cc index 751cf76e32..93ae99f4d2 100755 --- a/mindspore/ccsrc/session/ascend_session.cc +++ b/mindspore/ccsrc/session/ascend_session.cc @@ -18,6 +18,7 @@ #include "operator/ops.h" #include "ir/meta_tensor.h" #include "ir/anf.h" +#include "common/trans.h" #include "device/kernel_runtime.h" #include "device/ascend/kernel_select_ascend.h" #include "device/ascend/kernel_build_ascend.h" @@ -730,8 +731,8 @@ void AscendSession::SetChildGraphParameter(const tensor::TensorPtr &front_tensor size_t tensor_size = front_tensor->data().nbytes(); auto addr = AnfAlgo::GetOutputAddr(backend_parameter, 0); MS_EXCEPTION_IF_NULL(addr); - if (!addr->SyncHostToDevice(front_tensor->shape(), tensor_size, front_tensor->data_type(), - front_tensor->data_c(false))) { + if (!addr->SyncHostToDevice(trans::GetRuntimePaddingShape(backend_parameter, 0), tensor_size, + front_tensor->data_type(), front_tensor->data_c(false))) { MS_LOG(EXCEPTION) << "Tensor SyncHostToDevice fail!"; } MS_LOG(INFO) << "Finish!"; diff --git a/mindspore/ccsrc/session/kernel_graph.cc b/mindspore/ccsrc/session/kernel_graph.cc index dbf6e07e7e..bbcc04e14b 100755 --- a/mindspore/ccsrc/session/kernel_graph.cc +++ b/mindspore/ccsrc/session/kernel_graph.cc @@ -143,6 +143,12 @@ CNodePtr KernelGraph::NewCNode(const std::vector &inputs) { cnode->set_abstract(std::make_shared()); // create kernel_info from new parameter auto kernel_info = std::make_shared(); + // if the node only has the primitive(such as getNext) or the node's input has a feature map input + // then the node's output is a feature map output + if (inputs.size() == 1 || std::any_of(inputs.begin() + 1, inputs.end(), + [&](const AnfNodePtr &node) { return AnfAlgo::IsFeatureMapOutput(node); })) { + kernel_info->SetFeatureMapFlag(true); + } cnode->set_kernel_info(kernel_info); AnfAlgo::SetGraphId(graph_id_, cnode.get()); return cnode; @@ -162,22 +168,26 @@ CNodePtr KernelGraph::NewCNode(const CNodePtr &cnode) { ParameterPtr KernelGraph::NewParameter(const ParameterPtr ¶meter) { ParameterPtr new_parameter = add_parameter(); MS_EXCEPTION_IF_NULL(new_parameter); + // create kernel_info form new parameter + auto kernel_info = std::make_shared(); size_t output_tensor_num = 1; // if use default parameter = nullptr,it remarks create a new parameter from no parameter if (parameter == nullptr) { new_parameter->set_abstract(std::make_shared()); + kernel_info->SetFeatureMapFlag(true); } else { // if don't use default parameter = nullptr,it remarks create a new parameter from a old parameter new_parameter->set_abstract(parameter->abstract()); new_parameter->set_name(parameter->name()); - if (parameter->has_default()) { + if (AnfAlgo::IsParameterWeight(parameter)) { new_parameter->set_default_param(parameter->default_param()); + 
diff --git a/mindspore/ccsrc/session/kernel_graph.cc b/mindspore/ccsrc/session/kernel_graph.cc
index dbf6e07e7e..bbcc04e14b 100755
--- a/mindspore/ccsrc/session/kernel_graph.cc
+++ b/mindspore/ccsrc/session/kernel_graph.cc
@@ -143,6 +143,12 @@ CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
   cnode->set_abstract(std::make_shared<abstract::AbstractNone>());
   // create kernel_info from new parameter
   auto kernel_info = std::make_shared<device::KernelInfo>();
+  // if the node has only the primitive input (such as GetNext) or any of its inputs is a feature map,
+  // then the node's output is a feature map output
+  if (inputs.size() == 1 || std::any_of(inputs.begin() + 1, inputs.end(),
+                                        [&](const AnfNodePtr &node) { return AnfAlgo::IsFeatureMapOutput(node); })) {
+    kernel_info->SetFeatureMapFlag(true);
+  }
   cnode->set_kernel_info(kernel_info);
   AnfAlgo::SetGraphId(graph_id_, cnode.get());
   return cnode;
@@ -162,22 +168,26 @@ CNodePtr KernelGraph::NewCNode(const CNodePtr &cnode) {
 ParameterPtr KernelGraph::NewParameter(const ParameterPtr &parameter) {
   ParameterPtr new_parameter = add_parameter();
   MS_EXCEPTION_IF_NULL(new_parameter);
+  // create kernel_info from new parameter
+  auto kernel_info = std::make_shared<device::KernelInfo>();
   size_t output_tensor_num = 1;
   // if use default parameter = nullptr,it remarks create a new parameter from no parameter
   if (parameter == nullptr) {
     new_parameter->set_abstract(std::make_shared<abstract::AbstractNone>());
+    kernel_info->SetFeatureMapFlag(true);
   } else {
     // if don't use default parameter = nullptr,it remarks create a new parameter from a old parameter
     new_parameter->set_abstract(parameter->abstract());
     new_parameter->set_name(parameter->name());
-    if (parameter->has_default()) {
+    if (AnfAlgo::IsParameterWeight(parameter)) {
       new_parameter->set_default_param(parameter->default_param());
+      kernel_info->SetFeatureMapFlag(false);
+    } else {
+      kernel_info->SetFeatureMapFlag(true);
     }
     // if output is a tuple tensor,now can use for loop to handle tuple tensor
     output_tensor_num = AnfAlgo::GetOutputTensorNum(parameter);
   }
-  // create kernel_info form new parameter
-  auto kernel_info = std::make_shared<device::KernelInfo>();
   new_parameter->set_kernel_info(kernel_info);
   // create kernel_build_info for new parameter
   auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
@@ -217,6 +227,7 @@ std::vector<AnfNodePtr> KernelGraph::SplitTupleValueNodeToNodeList(const ValueNodePtr &value_node) {
   AddValueNodeToGraph(new_value_node);
   auto kernel_info = std::make_shared<device::KernelInfo>();
   new_value_node->set_kernel_info(kernel_info);
+  kernel_info->SetFeatureMapFlag(false);
   // create kernel_build_info for new value node
   auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
   // set the format of value_node to DEFAULT_FORMAT
@@ -240,6 +251,7 @@ ValueNodePtr KernelGraph::NewValueNode(const ValueNodePtr &value_node) {
   new_value_node->set_abstract(value_node->abstract());
   // create kernel_info fo new value node
   auto kernel_info = std::make_shared<device::KernelInfo>();
+  kernel_info->SetFeatureMapFlag(false);
   new_value_node->set_kernel_info(kernel_info);
   // create kernel_build_info for new value node
   auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
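Reviewer note: taken together, the constructors above establish the invariant that AnfAlgo::IsFeatureMapOutput relies on: value nodes are never feature maps, a parameter is a feature map unless it is a weight, and a CNode is a feature map if it has no real inputs or at least one feature-map input. A hypothetical check of that invariant (`weight_ptr` stands for an existing ParameterPtr that carries a default value):

// --- reviewer sketch, not part of the patch ---
auto data = kernel_graph->NewParameter();               // no default param  -> flag true
auto weight = kernel_graph->NewParameter(weight_ptr);   // IsParameterWeight -> flag false
auto mul = kernel_graph->NewCNode({NewValueNode(prim::kPrimMul), data, weight});
// `data` is a feature map, so the std::any_of in NewCNode marks `mul` as one too:
assert(AnfAlgo::IsFeatureMapOutput(mul));
// --- end sketch ---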
diff --git a/mindspore/ccsrc/session/session_basic.cc b/mindspore/ccsrc/session/session_basic.cc
index d2a255229d..bea51037bf 100755
--- a/mindspore/ccsrc/session/session_basic.cc
+++ b/mindspore/ccsrc/session/session_basic.cc
@@ -20,6 +20,7 @@
 #include "pipeline/parse/data_converter.h"
 #include "ir/manager.h"
 #include "operator/ops.h"
+#include "common/trans.h"
 #include "utils/context/ms_context.h"
 #include "utils/config_manager.h"
 #include "session/anf_runtime_algorithm.h"
@@ -124,7 +125,8 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const KernelGraph &graph,
   MS_EXCEPTION_IF_NULL(ms_context);
   if (ms_context->enable_pynative_infer()) {
     tensor->set_device_address(AnfAlgo::GetMutableOutputAddr(node, output_index));
-  } else if (!address->SyncDeviceToHost(tensor->shape(), LongToSize(tensor->data().nbytes()), tensor->data_type(),
+  } else if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, output_index),
+                                        LongToSize(tensor->data().nbytes()), tensor->data_type(),
                                         tensor->data_c(true))) {
     MS_LOG(INFO) << "output sync device to host error!!!";
     tensor->set_dirty(false);
@@ -369,7 +371,7 @@ ParameterPtr ConstructRunOpParameter(const std::shared_ptr<KernelGraph> &graph,
     kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{input_tensor->device_address()->type_id()});
   }
   AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), param.get());
-  // construct abstract of parameter
+  // construct the abstract of the parameter
   auto abstract = std::make_shared<abstract::AbstractTensor>(input_tensor);
   param->set_abstract(abstract);
   return param;
@@ -548,7 +550,8 @@ void SessionBasic::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
     if (need_sync) {
       tensor->set_device_address(device_address);
       MS_EXCEPTION_IF_NULL(device_address);
-      if (!device_address->SyncHostToDevice(tensor->shape(), LongToSize(tensor->data().nbytes()), tensor->data_type(),
+      if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0),
+                                            LongToSize(tensor->data().nbytes()), tensor->data_type(),
                                             tensor->data_c(false))) {
         MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
       }
@@ -620,8 +623,8 @@ void SessionBasic::Summary(KernelGraph *graph) {
     (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
     tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
     MS_EXCEPTION_IF_NULL(address);
-    if (!address->SyncDeviceToHost(tensor->shape(), LongToSize(tensor->data().nbytes()), tensor->data_type(),
-                                   tensor->data_c(true))) {
+    if (!address->SyncDeviceToHost(trans::GetRuntimePaddingShape(node, index), LongToSize(tensor->data().nbytes()),
+                                   tensor->data_type(), tensor->data_c(true))) {
       MS_LOG(ERROR) << "Failed to sync output from device to host.";
     }
     tensor->set_dirty(false);
diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h
index 39b4b7a160..79a4b216fb 100644
--- a/mindspore/ccsrc/utils/utils.h
+++ b/mindspore/ccsrc/utils/utils.h
@@ -197,8 +197,8 @@ const std::set<std::string> kOptOperatorSet = {
   kApplyRMSPropOpName,
 };
 
-const std::set<std::string> kSpecialFormatSet = {kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0,
-                                                 kOpFormat_FRAC_NZ, kOpFormat_C1HWNCoC0};
+const std::set<std::string> kNeedTransFormatSet = {kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0,
+                                                   kOpFormat_FRAC_NZ, kOpFormat_C1HWNCoC0};
 
 static inline void ChangeFileMode(const std::string& file_name, mode_t mode) {
   if (access(file_name.c_str(), F_OK) != 0) {
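Reviewer note: the rename makes the set's intent explicit: these are the formats whose tensors cannot be copied verbatim and need a transdata conversion first. Membership is tested the usual std::set way; a hypothetical helper (not part of the patch; only the set itself is from utils.h):

// --- reviewer sketch, not part of the patch ---
inline bool IsNeedTransFormat(const std::string &format) {
  return kNeedTransFormatSet.find(format) != kNeedTransFormatSet.end();
}
// --- end sketch ---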
diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion_test.cc
index e7831ec353..44b9b3df69 100644
--- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion_test.cc
+++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion_test.cc
@@ -80,6 +80,8 @@ TEST_F(TestHWLayerNormBetaGammaBackpropFusion, layernorm_beta_gamma_backprop_fusion) {
   builder1.SetOutputsDeviceType({kNumberTypeFloat32});
   cast0->set_kernel_info(std::make_shared<device::KernelInfo>());
   cast1->set_kernel_info(std::make_shared<device::KernelInfo>());
+  cast0->set_abstract(x_abstract);
+  cast1->set_abstract(x_abstract);
   AnfAlgo::SetSelectKernelBuildInfo(builder1.Build(), cast0.get());
   AnfAlgo::SetSelectKernelBuildInfo(builder1.Build(), cast1.get());
diff --git a/tests/ut/cpp/session/anf_runtime_algorithm_test.cc b/tests/ut/cpp/session/anf_runtime_algorithm_test.cc
index 2af2a7413b..6375d1a758 100644
--- a/tests/ut/cpp/session/anf_runtime_algorithm_test.cc
+++ b/tests/ut/cpp/session/anf_runtime_algorithm_test.cc
@@ -211,8 +211,8 @@ TEST_F(AnfRuntimeAlgorithmTest, EraseNodeAttr) {
 TEST_F(AnfRuntimeAlgorithmTest, GetInputTensorNum) {
   auto kernel_graph = std::make_shared<KernelGraph>();
   // test cnode node
-  auto parameter_one = kernel_graph->add_parameter();
-  auto parameter_two = kernel_graph->add_parameter();
+  auto parameter_one = kernel_graph->NewParameter();
+  auto parameter_two = kernel_graph->NewParameter();
   std::vector<AnfNodePtr> add_inputs{NewValueNode(prim::kPrimTensorAdd), parameter_one, parameter_two};
   auto add = kernel_graph->NewCNode(add_inputs);
   EXPECT_EQ(AnfAlgo::GetInputTensorNum(add), 2);
@@ -247,9 +247,11 @@ TEST_F(AnfRuntimeAlgorithmTest, GetOutputFormat) {
   auto kernel_graph = std::make_shared<KernelGraph>();
-  std::vector<AnfNodePtr> inputs;
-  inputs.push_back(NewValueNode(prim::kPrimTensorAdd));
+  std::vector<AnfNodePtr> inputs = {NewValueNode(prim::kPrimTensorAdd), kernel_graph->NewParameter(),
+                                    kernel_graph->NewParameter()};
   auto add = kernel_graph->NewCNode(inputs);
+  std::vector<size_t> shape = {1, 2, 3, 4};
+  AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32, kNumberTypeFloat32}, {shape, shape}, add.get());
   MS_EXCEPTION_IF_NULL(add);
   add->set_kernel_info(std::make_shared<device::KernelInfo>());
   auto d_kernel_info = add->kernel_info();
@@ -266,8 +268,8 @@
 TEST_F(AnfRuntimeAlgorithmTest, GetInputFormat) {
   auto kernel_graph = std::make_shared<KernelGraph>();
-  std::vector<AnfNodePtr> inputs;
-  inputs.push_back(NewValueNode(prim::kPrimTensorAdd));
+  std::vector<AnfNodePtr> inputs = {NewValueNode(prim::kPrimTensorAdd), kernel_graph->NewParameter(),
+                                    kernel_graph->NewParameter()};
   auto add = kernel_graph->NewCNode(inputs);
   MS_EXCEPTION_IF_NULL(add);
   add->set_kernel_info(std::make_shared<device::KernelInfo>());
@@ -345,7 +347,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetPrevNodeOutputInferShape) {
   std::vector<int> shp{2, 32, 224, 224};
   auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
   // test parameter node as input
-  auto parameter_node = kernel_graph->add_parameter();
+  auto parameter_node = kernel_graph->NewParameter();
   MS_EXCEPTION_IF_NULL(parameter_node);
   parameter_node->set_abstract(x_abstract);
   EXPECT_THROW(AnfAlgo::GetPrevNodeOutputInferShape(parameter_node, 0), std::runtime_error);
@@ -387,13 +389,13 @@ TEST_F(AnfRuntimeAlgorithmTest, GetInputDeviceShape) {
   auto kernel_graph = std::make_shared<KernelGraph>();
   std::vector<int> shp{2, 32, 224, 224};
   auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
-  auto parameter_one = kernel_graph->add_parameter();
+  auto parameter_one = kernel_graph->NewParameter();
   MS_EXCEPTION_IF_NULL(parameter_one);
   parameter_one->set_abstract(x_abstract);
-  auto parameter_two = kernel_graph->add_parameter();
+  auto parameter_two = kernel_graph->NewParameter();
   MS_EXCEPTION_IF_NULL(parameter_two);
   parameter_two->set_abstract(x_abstract);
-  auto parameter_third = kernel_graph->add_parameter();
+  auto parameter_third = kernel_graph->NewParameter();
   MS_EXCEPTION_IF_NULL(parameter_third);
   parameter_third->set_abstract(x_abstract);
   // test cnode as input
@@ -466,8 +468,8 @@ TEST_F(AnfRuntimeAlgorithmTest, GetOutputDeviceDataTypeTest) {
 TEST_F(AnfRuntimeAlgorithmTest, GetInputDeviceDataTypeTest) {
   auto kernel_graph = std::make_shared<KernelGraph>();
-  std::vector<AnfNodePtr> inputs;
-  inputs.push_back(NewValueNode(prim::kPrimTensorAdd));
+  std::vector<AnfNodePtr> inputs = {NewValueNode(prim::kPrimTensorAdd), kernel_graph->NewParameter(),
+                                    kernel_graph->NewParameter()};
   auto add = kernel_graph->NewCNode(inputs);
   MS_EXCEPTION_IF_NULL(add);
   add->set_kernel_info(std::make_shared<device::KernelInfo>());
diff --git a/tests/ut/cpp/session/kernel_graph_test.cc b/tests/ut/cpp/session/kernel_graph_test.cc
index 55e1b1b28e..a62af9c892 100644
--- a/tests/ut/cpp/session/kernel_graph_test.cc
+++ b/tests/ut/cpp/session/kernel_graph_test.cc
@@ -140,11 +140,11 @@ TEST_F(KernelGraphTest, SetExecOrderByDefault) {
   std::vector<int> shape = {2, 32, 224, 224};
   auto abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shape);
 
-  auto x_parameter = kernel_graph->add_parameter();
+  auto x_parameter = kernel_graph->NewParameter();
   MS_EXCEPTION_IF_NULL(x_parameter);
   x_parameter->set_name("x_parameter");
   x_parameter->set_abstract(abstract);
-  auto y_parameter = kernel_graph->add_parameter();
+  auto y_parameter = kernel_graph->NewParameter();
   MS_EXCEPTION_IF_NULL(y_parameter);
   y_parameter->set_name("y_parameter");
   y_parameter->set_abstract(abstract);
@@ -153,7 +153,7 @@ TEST_F(KernelGraphTest, SetExecOrderByDefault) {
   MS_EXCEPTION_IF_NULL(add);
   add->set_abstract(abstract);
 
-  auto z_parameter = kernel_graph->add_parameter();
+  auto z_parameter = kernel_graph->NewParameter();
   MS_EXCEPTION_IF_NULL(z_parameter);
   z_parameter->set_name("z_parameter");
   z_parameter->set_abstract(abstract);