diff --git a/CMakeLists.txt b/CMakeLists.txt index 4e3580adc5..33f4ef35a8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.14.1) +cmake_minimum_required(VERSION 3.14.0) project(MindSpore) if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.3.0) @@ -14,18 +14,25 @@ endif() if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin") set(CMAKE_OSX_SYSROOT "") - set(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -O2 -Winconsistent-missing-override -Wuser-defined-warnings -Wno-return-std-move -Wno-unused-private-field -Wno-unused-lambda-capture -Wno-sign-compare -Wno-overloaded-virtual -Wno-unneeded-internal-declaration -Wno-unused-variable -Wno-pessimizing-move -Wno-inconsistent-missing-override -DHALF_ENABLE_CPP11_USER_LITERALS=0 -D_FORTIFY_SOURCE=2") + set(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -O2 -Winconsistent-missing-override -Wuser-defined-warnings \ + -Wno-return-std-move -Wno-unused-private-field -Wno-unused-lambda-capture -Wno-sign-compare \ + -Wno-overloaded-virtual -Wno-unneeded-internal-declaration -Wno-unused-variable -Wno-pessimizing-move \ + -Wno-inconsistent-missing-override -DHALF_ENABLE_CPP11_USER_LITERALS=0 -D_FORTIFY_SOURCE=2") else() - set(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -O2 -Wl,--allow-shlib-undefined -DHALF_ENABLE_CPP11_USER_LITERALS=0 -D_FORTIFY_SOURCE=2") + set(CMAKE_CXX_FLAGS_RELEASE "$ENV{CXXFLAGS} -O2 -Wl,--allow-shlib-undefined \ + -DHALF_ENABLE_CPP11_USER_LITERALS=0 -D_FORTIFY_SOURCE=2") endif() if(ENABLE_PYTHON) add_compile_definitions(ENABLE_PYTHON) endif() -set(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} -O0 -g2 -ggdb -fno-inline-functions -fno-omit-frame-pointer -Wl,--allow-shlib-undefined -D_LIBCPP_INLINE_VISIBILITY='' -D_LIBCPP_DISABLE_EXTERN_TEMPLATE=1 -DHALF_ENABLE_CPP11_USER_LITERALS=0 -D_FORTIFY_SOURCE=2 -Wno-cpp") +set(CMAKE_CXX_FLAGS_DEBUG "$ENV{CXXFLAGS} -O0 -g2 -ggdb -fno-inline-functions -fno-omit-frame-pointer \ + -Wl,--allow-shlib-undefined -D_LIBCPP_INLINE_VISIBILITY='' -D_LIBCPP_DISABLE_EXTERN_TEMPLATE=1 \ + -DHALF_ENABLE_CPP11_USER_LITERALS=0 -D_FORTIFY_SOURCE=2 -Wno-cpp") -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -I/usr/local/include -std=c++17 -Werror -Wall -Wno-deprecated-declarations -fPIC") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -I/usr/local/include -std=c++17 \ + -Werror -Wall -Wno-deprecated-declarations -fPIC") set(CMAKE_EXPORT_COMPILE_COMMANDS ON) set(PYBIND11_CPP_STANDARD -std=c++17) diff --git a/mindspore/_extends/builtin_operations.py b/mindspore/_extends/builtin_operations.py index 3a49fc01f1..3722878876 100644 --- a/mindspore/_extends/builtin_operations.py +++ b/mindspore/_extends/builtin_operations.py @@ -132,6 +132,16 @@ def Depend(value, expr): return value +def UpdateState(monad, expr): + """Implement `UpdateState`.""" + return monad + + +def Load(value, u=None): + """Implement `Load`.""" + return value + + # only used in PyNative mode def make_ref(key, value, ref): return value diff --git a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_metadata.cc b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_metadata.cc index 7ae84f7441..70e6498a7b 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_metadata.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_metadata.cc @@ -42,14 +42,16 @@ void AicpuMetadataInfo(const CNodePtr &kernel_node, std::vector inputs_format{}; std::vector inputs_type{}; if (op_name == kPrint || op_name == kPack || op_name == kMeshgrid) { - for (size_t input_index = 0; input_index < 
AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    for (size_t input_index = 0; input_index < input_num; ++input_index) {
       inputs_format.emplace_back(kOpFormat_DEFAULT);
       inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index));
     }
   }
   std::vector<std::string> outputs_format;
   std::vector<TypeId> outputs_type;
-  for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) {
+  size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+  for (size_t output_index = 0; output_index < output_num; ++output_index) {
     outputs_format.emplace_back(kOpFormat_DEFAULT);
     outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index));
   }
diff --git a/mindspore/ccsrc/backend/kernel_compiler/common_utils.cc b/mindspore/ccsrc/backend/kernel_compiler/common_utils.cc
index 77e1b295b5..aa95aea562 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/common_utils.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/common_utils.cc
@@ -139,9 +139,9 @@ bool CheckCache(const std::string &kernel_name) {
   std::string kernel_json = bin_map->Search(kernel_name);
   bool ret = (!kernel_json.empty());
   if (ret) {
-    MS_LOG(INFO) << "Kernel name:" << kernel_name << " has registed.";
+    MS_LOG(INFO) << "Kernel name:" << kernel_name << " has registered.";
   } else {
-    MS_LOG(INFO) << "Kernel name:" << kernel_name << " will been registed.";
+    MS_LOG(INFO) << "Kernel name:" << kernel_name << " will be registered.";
   }
   return ret;
 }
@@ -730,30 +730,6 @@ bool GetInputTensorValue(const AnfNodePtr &anf_node, size_t input_idx, nlohmann:
   return false;
 }
-void GetGraphRealOutput(const FuncGraphPtr &func_graph, std::vector<std::pair<AnfNodePtr, size_t>> *node_list) {
-  MS_EXCEPTION_IF_NULL(func_graph);
-  MS_EXCEPTION_IF_NULL(node_list);
-  auto output = func_graph->output();
-  MS_EXCEPTION_IF_NULL(output);
-  if (AnfAlgo::IsRealKernel(output)) {
-    // single output.
-    node_list->push_back(std::make_pair(output, 0));
-    return;
-  } else if (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) {
-    auto output_cnode = output->cast<CNodePtr>();
-    MS_EXCEPTION_IF_NULL(output_cnode);
-    // multi output.
-    auto &inputs = output_cnode->inputs();
-    for (size_t i = 1; i < inputs.size(); ++i) {
-      auto in_with_idx = AnfAlgo::VisitKernel(inputs[i], 0);
-      node_list->push_back(in_with_idx);
-    }
-    return;
-  }
-  MS_EXCEPTION(ArgumentError) << "Unknown output type: " << output->DebugString(2)
-                              << " of graph: " << func_graph->ToString();
-}
-
 bool IsWeightBoundary(const AnfNodePtr &node) {
   if (node->isa<ValueNode>()) {
     return true;
   }
@@ -776,7 +752,7 @@ std::vector<int64_t> GetReduceAttrAxis(const CNodePtr &cnode) {
   MS_EXCEPTION_IF_NULL(primitive);
   auto axis_attr = primitive->GetAttr(kAxis);
   if (axis_attr == nullptr) {
-    MS_LOG(ERROR) << "This node does't have axie attr.";
+    MS_LOG(ERROR) << "This node doesn't have axis attr.";
     return std::vector<int64_t>();
   }
   std::vector<int64_t> axis_list;
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/map_cache_idx_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/map_cache_idx_cpu_kernel.cc
index 2a47514925..00c18d8808 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/map_cache_idx_cpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/map_cache_idx_cpu_kernel.cc
@@ -181,7 +181,8 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
   std::vector<size_t> out_shape;
   out_shape.emplace_back(miss_count);
   std::vector<TypeId> dtypes;
-  for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(node_); i++) {
+  size_t output_num = AnfAlgo::GetOutputTensorNum(node_);
+  for (size_t i = 0; i < output_num; i++) {
     dtypes.push_back(AnfAlgo::GetOutputInferDataType(node_, i));
   }
   AnfAlgo::SetOutputInferTypeAndShape(dtypes, {AnfAlgo::GetOutputInferShape(node_, 0), out_shape, out_shape, out_shape},
diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/sub_and_filter_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/sub_and_filter_cpu_kernel.cc
index f3090dbe71..6a8152e86b 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/cpu/sub_and_filter_cpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/sub_and_filter_cpu_kernel.cc
@@ -69,7 +69,8 @@ void SubAndFilterCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
   std::vector<size_t> out_shape;
   out_shape.emplace_back(count);
   std::vector<TypeId> dtypes;
-  for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(node_); i++) {
+  size_t output_num = AnfAlgo::GetOutputTensorNum(node_);
+  for (size_t i = 0; i < output_num; i++) {
     dtypes.push_back(AnfAlgo::GetOutputInferDataType(node_, i));
   }
   AnfAlgo::SetOutputInferTypeAndShape(dtypes, {out_shape, out_shape}, node_.get());
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/other/assign_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/other/assign_gpu_kernel.cc
index 4e07463a6c..b17a55c37c 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/other/assign_gpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/other/assign_gpu_kernel.cc
@@ -29,5 +29,8 @@ MS_REG_GPU_KERNEL_ONE(
 MS_REG_GPU_KERNEL_ONE(
   Assign, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
   AssignGpuKernel, int)
+MS_REG_GPU_KERNEL_ONE(
+  Assign, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
+  AssignGpuKernel, int64_t)
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel_metadata.cc b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel_metadata.cc
index 52582cf568..736258505e 100755
--- a/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel_metadata.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hccl_kernel_metadata.cc
@@ -63,13 +63,15 @@ void HcclMetadataInfo(const CNodePtr &kernel_node,
   std::vector<std::string> inputs_format{};
   std::vector<TypeId> inputs_type{};
-  for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
+  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+  for (size_t input_index = 0; input_index < input_num; ++input_index) {
     inputs_format.emplace_back(GetKernelFormat(kernel_node, input_index));
     inputs_type.push_back(type);
   }
   std::vector<std::string> outputs_format;
   std::vector<TypeId> outputs_type;
-  for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) {
+  size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+  for (size_t output_index = 0; output_index < output_num; ++output_index) {
     if (op_name == kReduceScatter && AnfAlgo::GetNodeAttr<int64_t>(kernel_node, kAttrFusion) > 0) {
       outputs_format.emplace_back(GetKernelFormat(kernel_node, 0));
     } else {
diff --git a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_util.cc b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_util.cc
index c20ffc5882..7c7bc443a8 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_util.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/hccl/hcom_util.cc
@@ -31,7 +31,8 @@ bool IsPyNativeMode() {
 bool HcomUtil::GetKernelInputShape(const AnfNodePtr &anf_node, vector<vector<size_t>> *hccl_kernel_intput_shape_list) {
   MS_EXCEPTION_IF_NULL(anf_node);
   MS_EXCEPTION_IF_NULL(hccl_kernel_intput_shape_list);
-  for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(anf_node); ++i) {
+  size_t input_num = AnfAlgo::GetInputTensorNum(anf_node);
+  for (size_t i = 0; i < input_num; ++i) {
     std::vector<size_t> shape_i = AnfAlgo::GetInputDeviceShape(anf_node, i);
     hccl_kernel_intput_shape_list->emplace_back(shape_i);
   }
@@ -42,7 +43,8 @@ bool HcomUtil::GetKernelInputShape(const AnfNodePtr &anf_node,
                                    vector<vector<size_t>> *hccl_kernel_output_shape_list) {
   MS_EXCEPTION_IF_NULL(anf_node);
   MS_EXCEPTION_IF_NULL(hccl_kernel_output_shape_list);
-  for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(anf_node); ++i) {
+  size_t output_num = AnfAlgo::GetOutputTensorNum(anf_node);
+  for (size_t i = 0; i < output_num; ++i) {
     std::vector<size_t> shape_i = AnfAlgo::GetOutputDeviceShape(anf_node, i);
     hccl_kernel_output_shape_list->emplace_back(shape_i);
   }
@@ -53,11 +55,12 @@ bool HcomUtil::GetKernelOutputShape(const AnfNodePtr &anf_node,
                                vector<HcclDataType> *data_type_list) {
   MS_EXCEPTION_IF_NULL(anf_node);
   MS_EXCEPTION_IF_NULL(data_type_list);
-  for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(anf_node); ++i) {
+  size_t input_num = AnfAlgo::GetInputTensorNum(anf_node);
+  for (size_t i = 0; i < input_num; ++i) {
     auto type_ptr = AnfAlgo::GetPrevNodeOutputDeviceDataType(anf_node, i);
     auto iter = CONST_OP_HCOM_DATA_TYPE_MAP.find(type_ptr);
     if (iter == CONST_OP_HCOM_DATA_TYPE_MAP.end()) {
-      MS_LOG(EXCEPTION) << "HcomDataType cann't support Current Ascend Data Type : " << type_ptr;
+      MS_LOG(EXCEPTION) << "HcomDataType doesn't support current Ascend data type: " << type_ptr;
     }
     data_type_list->emplace_back(iter->second);
   }
diff --git a/mindspore/ccsrc/backend/kernel_compiler/host/host_kernel_metadata.cc b/mindspore/ccsrc/backend/kernel_compiler/host/host_kernel_metadata.cc
index d2a04e4160..5ec112dfe3 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/host/host_kernel_metadata.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/host/host_kernel_metadata.cc
@@ -37,13 +37,15 @@ void HostMetadataInfo(const CNodePtr &kernel_node,
   std::vector<std::string> inputs_format{};
   std::vector<TypeId> inputs_type{};
-  for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
+  size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+  for (size_t input_index = 0; input_index < input_num; ++input_index) {
     inputs_format.emplace_back(kOpFormat_DEFAULT);
     inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index));
   }
   std::vector<std::string> outputs_format;
   std::vector<TypeId> outputs_type;
-  for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) {
+  size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+  for (size_t output_index = 0; output_index < output_num; ++output_index) {
     outputs_format.emplace_back(kOpFormat_DEFAULT);
     outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index));
   }
diff --git a/mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.cc b/mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.cc
index b534fd112b..2932e90064 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.cc
@@ -30,7 +30,7 @@ std::string KernelBuildInfo::GetInputFormat(size_t input_index) const {
 std::string KernelBuildInfo::GetOutputFormat(size_t output_index) const {
   if (output_index >= outputs_format_.size()) {
-    MS_LOG(ERROR) << "The index [" << output_index << "] is exceed the number of input node";
+    MS_LOG(ERROR) << "The index [" << output_index << "] exceeds the number of outputs";
     return kInvalidFormat;
   }
   return outputs_format_[output_index];
 }
diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/label_switch.cc b/mindspore/ccsrc/backend/kernel_compiler/rts/label_switch.cc
index e5315efc6b..5f2be915d7 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/rts/label_switch.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/rts/label_switch.cc
@@ -86,6 +86,9 @@ std::vector<std::shared_ptr<kernel::KernelBuildInfo>> LabelSwitchDesc::GetKernel
     builder.SetProcessor(AICORE);
     builder.SetKernelType(RT_KERNEL);
     builder.SetFusionType(OPAQUE);
+    // LabelSwitch always returns UMonad.
+    builder.SetOutputsFormat({kOpFormat_DEFAULT});
+    builder.SetOutputsDeviceType({TypeId::kObjectTypeUMonad});
     label_switch_build_info.emplace_back(builder.Build());
   }
   return label_switch_build_info;
diff --git a/mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel_info.cc b/mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel_info.cc
index 59ac61fd81..47fca0f44e 100755
--- a/mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel_info.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/rts/rt_kernel_info.cc
@@ -74,11 +74,10 @@ void GetRtKelInfo(const CNodePtr &kernel_node,
     input_types.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, i));
   }
   kernel_build_info_builder->SetInputsDeviceType(input_types);
-  // set output info
-  auto output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
-  kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>(output_num, kOpFormat_DEFAULT));
-  kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>(output_num, TypeId::kTypeUnknown));
-  // set ohter info
+  // Kernel ops in white-list such as 'LabelSet' always return UMonad.
+ kernel_build_info_builder->SetOutputsFormat({kOpFormat_DEFAULT}); + kernel_build_info_builder->SetOutputsDeviceType({TypeId::kObjectTypeUMonad}); + // set other info kernel_build_info_builder->SetFusionType(kernel::FusionType::OPAQUE); kernel_build_info_builder->SetProcessor(kernel::Processor::AICORE); kernel_build_info_builder->SetKernelType(KernelType::RT_KERNEL); diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.cc index 1c52b3eade..6c6ab73335 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_build.cc @@ -1052,10 +1052,16 @@ size_t TbeKernelBuild::GetOptionalInput(const mindspore::CNodePtr &cnode, bool i auto node_name = AnfAlgo::GetCNodeName(cnode); auto op_info = tbe::TbeDynamicShapeUtil::FindOp(node_name, cnode); MS_EXCEPTION_IF_NULL(cnode); - if (op_info->inputs_ptr().size() < (cnode->inputs().size() - 1)) { + auto node_inputs_size = cnode->inputs().size(); + for (auto &input : cnode->inputs()) { + if (HasAbstractMonad(input)) { + node_inputs_size--; + } + } + if (op_info->inputs_ptr().size() < (node_inputs_size - 1)) { MS_EXCEPTION(ArgumentError) << "op info error, node name:" << cnode->fullname_with_scope(); } - return (op_info->inputs_ptr().size() + 1 - cnode->inputs().size()); + return (op_info->inputs_ptr().size() + 1 - node_inputs_size); } std::string TbeKernelBuild::GetRealOpType(const std::string &origin_type) { @@ -1103,6 +1109,9 @@ bool TbeKernelBuild::GenFusionComputeInputJson(const mindspore::CNodePtr &cnode, bool is_dynamic_input = IsDynamicInput(cnode); for (size_t i = 1; i < cnode->inputs().size(); ++i) { auto input = cnode->input(i); + if (HasAbstractMonad(input)) { + continue; + } auto kernel_idx = AnfAlgo::VisitKernel(input, 0); auto real_node = kernel_idx.first; size_t real_idx = kernel_idx.second; diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc index a42539e1dd..802a40341c 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc @@ -112,6 +112,10 @@ AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr & const KernelSelectPtr &kernel_select) { MS_EXCEPTION_IF_NULL(node); auto input_node = AnfAlgo::GetInputNode(node, index); + if (HasAbstractMonad(input_node)) { + // No transfer for monad inputs. + return input_node; + } auto node_with_index = AnfAlgo::VisitKernel(input_node, 0); MS_EXCEPTION_IF_NULL(node_with_index.first); auto real_input = node_with_index.first; @@ -330,8 +334,9 @@ AnfNodePtr InsertTransOpForInput(const FuncGraphPtr &func_graph, const AnfNodePt auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); std::vector new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)}; - size_t in_num = AnfAlgo::GetInputTensorNum(cnode); + size_t in_num = AnfAlgo::GetInputNum(cnode); // include monads. for (size_t input_index = 0; input_index < in_num; ++input_index) { + // Monad inputs keep unchanged from GetTransInputNodePtr(). 
AnfNodePtr input_node = GetTransInputNodePtr(func_graph, cnode, input_index, kernel_select); MS_EXCEPTION_IF_NULL(input_node); new_inputs.push_back(input_node); @@ -352,12 +357,18 @@ AnfNodePtr InsertTransOpForInput(const FuncGraphPtr &func_graph, const AnfNodePt CNodePtr InsertCastForInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { MS_EXCEPTION_IF_NULL(cnode); std::vector new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)}; - size_t in_num = AnfAlgo::GetInputTensorNum(cnode); + size_t in_num = AnfAlgo::GetInputNum(cnode); // include monads. for (size_t input_index = 0; input_index < in_num; ++input_index) { + auto cur_input = AnfAlgo::GetInputNode(cnode, input_index); + if (HasAbstractMonad(cur_input)) { + // No cast for monad inputs. + new_inputs.push_back(cur_input); + continue; + } auto prev_node = AnfAlgo::GetPrevNodeOutput(cnode, input_index); const auto infer_type = AnfAlgo::GetOutputInferDataType(prev_node.first, prev_node.second); TypeId origin_type(kTypeUnknown); - auto cur_input = AnfAlgo::GetInputNode(cnode, input_index); + auto kernel_with_index = AnfAlgo::VisitKernelWithReturnType(cur_input, 0); auto real_input_node = kernel_with_index.first; if (kernel::IsWeightBoundary(real_input_node) || func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.cc index 60b237da23..6457098627 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.cc @@ -244,7 +244,9 @@ void GetFusionScopeInputNodeList(const session::KernelGraph &kernel_graph, if (auto in = cnode->input(idx); std::find((*buffer_fusion_infos)[fusion_id].inputs_list.begin(), (*buffer_fusion_infos)[fusion_id].inputs_list.end(), in) == (*buffer_fusion_infos)[fusion_id].inputs_list.end()) { - (*buffer_fusion_infos)[fusion_id].inputs_list.push_back(in); + if (!HasAbstractMonad(in)) { + (*buffer_fusion_infos)[fusion_id].inputs_list.push_back(in); + } } } } diff --git a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.cc b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.cc index a7d3626b72..fc34fc09b3 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_memcpy_async_for_hccl_op.cc @@ -40,62 +40,6 @@ bool IsParameterOrValueNode(const AnfNodePtr &node) { return real_node->isa(); } -void SetInput(const CNodePtr &control_depend, const int index, const FuncGraphPtr &graph, const CNodePtr &hccl_node, - const std::vector &memcpy_async_list) { - MS_EXCEPTION_IF_NULL(control_depend); - MS_EXCEPTION_IF_NULL(graph); - MS_EXCEPTION_IF_NULL(hccl_node); - std::vector make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)}; - make_tuple_inputs.insert(make_tuple_inputs.end(), memcpy_async_list.begin(), memcpy_async_list.end()); - make_tuple_inputs.emplace_back(hccl_node); - auto make_tuple = graph->NewCNode(make_tuple_inputs); - MS_EXCEPTION_IF_NULL(make_tuple); - control_depend->set_input(IntToSize(index), make_tuple); -} - -void DealControlForGetitem(const CNodePtr &tuple_getitem, const FuncGraphPtr &graph, const CNodePtr &hccl_node, - const std::vector &memcpy_async_list) { - MS_EXCEPTION_IF_NULL(tuple_getitem); - auto manager = graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - auto 
&node_users = manager->node_users(); - auto iter = node_users.find(tuple_getitem); - if (iter == node_users.end()) { - MS_LOG(EXCEPTION) << "node has no output in manager" - << " trace: " << trace::DumpSourceLines(hccl_node); - } - for (const auto &node_index : iter->second) { - AnfNodePtr output = node_index.first; - MS_EXCEPTION_IF_NULL(output); - if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimControlDepend)) { - SetInput(output->cast(), node_index.second, graph, hccl_node, memcpy_async_list); - } - } -} - -void TransferControl(const CNodePtr &hccl_node, const std::vector &memcpy_async_list, - const FuncGraphPtr &graph) { - MS_EXCEPTION_IF_NULL(hccl_node); - MS_EXCEPTION_IF_NULL(graph); - auto manager = graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - auto &node_users = manager->node_users(); - auto iter = node_users.find(hccl_node); - if (iter == node_users.end()) { - MS_LOG(EXCEPTION) << "node has no output in manager" - << " trace: " << trace::DumpSourceLines(hccl_node); - } - // find hccl_node's output which is a control depend - for (const auto &node_index : iter->second) { - AnfNodePtr output = node_index.first; - MS_EXCEPTION_IF_NULL(output); - if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimControlDepend)) { - SetInput(output->cast(), node_index.second, graph, hccl_node, memcpy_async_list); - } else if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimTupleGetItem)) { - DealControlForGetitem(output->cast(), graph, hccl_node, memcpy_async_list); - } - } -} // NodeUsersMap, for node B input i use node A, it will be one item in map with key: A, and value: (B, i) bool IsNodeOutPutUsedByOtherRealKernel(const AnfNodeIndexSet &node_users) { if (node_users.size() == 1) { @@ -155,7 +99,7 @@ bool InsertMemcpyAsyncForHcclOp::NeedInsertMemcpy(const FuncGraphPtr &graph, con void InsertMemcpyAsyncForHcclOp::InsertMemcpyAsync(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const { MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(hccl_node); - std::vector memcpy_async_list; + bool need_memcpy_async = false; std::vector new_inputs = {hccl_node->input(0)}; for (size_t i = 1; i < hccl_node->size(); ++i) { auto input = hccl_node->input(i); @@ -164,17 +108,17 @@ void InsertMemcpyAsyncForHcclOp::InsertMemcpyAsync(const FuncGraphPtr &graph, co if (memcpy_async == nullptr) { MS_LOG(EXCEPTION) << "Create memcpy_async op failed."; } - if (AnfAlgo::IsNodeDynamicShape(input)) { + if (input->isa() && AnfAlgo::IsNodeDynamicShape(input)) { AnfAlgo::SetNodeAttr(kAttrIsDynamicShape, MakeValue(true), memcpy_async); } new_inputs.push_back(memcpy_async); - memcpy_async_list.push_back(memcpy_async); + need_memcpy_async = true; } else { new_inputs.push_back(input); } } - if (!memcpy_async_list.empty()) { + if (need_memcpy_async) { CNodePtr new_hccl_node = std::make_shared(*hccl_node); new_hccl_node->set_inputs(new_inputs); auto manager = graph->manager(); @@ -182,9 +126,6 @@ void InsertMemcpyAsyncForHcclOp::InsertMemcpyAsync(const FuncGraphPtr &graph, co MS_LOG(DEBUG) << "start replace new_hccl_node to old hccl_node"; (void)manager->Replace(hccl_node, new_hccl_node); MS_LOG(DEBUG) << "end replace"; - - // transer hccl op's control to the memcpy_async - TransferControl(new_hccl_node, memcpy_async_list, graph); } } diff --git a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_pad_for_nms_with_mask.cc b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_pad_for_nms_with_mask.cc index c06a394f3c..1939c88abe 100644 --- 
a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_pad_for_nms_with_mask.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_pad_for_nms_with_mask.cc @@ -57,7 +57,7 @@ const AnfNodePtr InsertPadForNMSWithMask::Process(const FuncGraphPtr &func_graph return nullptr; } std::vector new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)}; - for (size_t input_idx = 0; input_idx < AnfAlgo::GetInputTensorNum(cnode); input_idx++) { + for (size_t input_idx = 0; input_idx < input_num; input_idx++) { auto cur_input = AnfAlgo::GetInputNode(cnode, input_idx); auto origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_idx); auto origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, input_idx); diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/convert_cast_format.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/convert_cast_format.cc index 552efb5e48..4eb055829f 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/format_type/convert_cast_format.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/convert_cast_format.cc @@ -40,7 +40,8 @@ const AnfNodePtr ConvertCastFormat::Process(const FuncGraphPtr &func_graph, cons } auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); - for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) { + size_t input_num = AnfAlgo::GetInputTensorNum(cnode); + for (size_t input_index = 0; input_index < input_num; ++input_index) { auto input_node = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(cnode, input_index), 0).first; MS_EXCEPTION_IF_NULL(input_node); if (!input_node->isa()) { @@ -77,7 +78,8 @@ void ConvertCastFormat::ChangeCastFormat(const CNodePtr &cast_node, const FuncGr MS_EXCEPTION_IF_NULL(node_info.first); auto cast_out_node = node_info.first->cast(); MS_EXCEPTION_IF_NULL(cast_out_node); - for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(cast_out_node); ++index) { + size_t input_num = AnfAlgo::GetInputTensorNum(cast_out_node); + for (size_t index = 0; index < input_num; ++index) { if (AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(cast_out_node->cast(), index), 0).first != cast_node) { continue; diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.cc index 20b900da45..70bea4a025 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/deal_ref_trans_and_cast.cc @@ -162,7 +162,8 @@ CNodePtr DealRefTransAndCast::DealRefForMultipleOutput(const FuncGraphPtr &func_ std::vector make_tuple_inputs; AbstractBasePtrList abstract_list; make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); - for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(cnode); ++output_index) { + size_t output_num = AnfAlgo::GetOutputTensorNum(cnode); + for (size_t output_index = 0; output_index < output_num; ++output_index) { CNodePtr final_node = CreatTupleGetItemNode(func_graph, cnode, output_index); // deal with ref output if (ref_infos.count(output_index) != 0) { diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/dynamic_rnn_grad_reformat.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/dynamic_rnn_grad_reformat.cc index 7c3dfa1f06..87818f2e8d 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/format_type/dynamic_rnn_grad_reformat.cc +++ 
b/mindspore/ccsrc/backend/optimizer/ascend/format_type/dynamic_rnn_grad_reformat.cc @@ -37,7 +37,7 @@ const AnfNodePtr DynamicRNNGradReformat::Process(const FuncGraphPtr &func_graph, MS_EXCEPTION_IF_NULL(node); auto split_v = node->cast(); MS_EXCEPTION_IF_NULL(split_v); - auto matmul = CheckAnfNodeIfCNodeAndInputSize(split_v->input(1), 3); + auto matmul = CheckAnfNodeIfCNodeAndInputSize(split_v->input(1), kMatMulInputTensorNum); MS_EXCEPTION_IF_NULL(matmul); auto input_node_with_idx = AnfAlgo::GetPrevNodeOutput(matmul, 0); auto input_node = input_node_with_idx.first; diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_cast.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_cast.cc index ed2dd6ffcc..b302150eb2 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_cast.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_cast.cc @@ -129,9 +129,21 @@ AnfNodePtr ProcessGraphKernelOp(const FuncGraphPtr &func_graph, const AnfNodePtr auto mng = sub_graph->manager(); MS_EXCEPTION_IF_NULL(mng); std::vector todo; - std::vector> graph_rets; kernel::GetValidKernelNodes(sub_graph, &todo); - kernel::GetGraphRealOutput(sub_graph, &graph_rets); + auto outputs = AnfAlgo::GetAllOutput(sub_graph->output(), {prim::kPrimTupleGetItem}); + std::vector> graph_rets; + for (auto &output : outputs) { + size_t index = 0; + if (IsPrimitiveCNode(output, prim::kPrimTupleGetItem)) { + ValuePtr tuple_index_value = GetValueNode(output->cast()->input(kInputNodeOutputIndexInTupleGetItem)); + MS_EXCEPTION_IF_NULL(tuple_index_value); + if (!tuple_index_value->isa()) { + MS_LOG(EXCEPTION) << "The index of tuple getitem is not int64"; + } + index = tuple_index_value->cast()->value(); + } + graph_rets.emplace_back(std::pair(output, index)); + } for (auto &t : todo) { AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), t); // process input diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_transdata_for_runop.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_transdata_for_runop.cc index 2a77999d45..a4ad2fff95 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_transdata_for_runop.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/insert_transdata_for_runop.cc @@ -33,7 +33,8 @@ bool RunOpInsertTransData::Run(const FuncGraphPtr &graph) { continue; } auto cnode = node->cast(); - for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(cnode); ++index) { + size_t input_num = AnfAlgo::GetInputTensorNum(cnode); + for (size_t index = 0; index < input_num; ++index) { auto prev_input_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, index); auto prev_node_out_infer_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, index); auto input_format = AnfAlgo::GetInputFormat(cnode, index); diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/merge_cast_to_op.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/merge_cast_to_op.cc index e751a121f4..e9250248e4 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/format_type/merge_cast_to_op.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/merge_cast_to_op.cc @@ -28,8 +28,6 @@ namespace mindspore { namespace opt { namespace { -const size_t kCastInputNum = 2; -const size_t kTupleGetitemInputNum = 3; bool AlternativeKernelInfoForInput(const CNodePtr &node, const TypeId dst_type, const size_t change_idx, const std::shared_ptr &candidate_kernel_info) { if (node == nullptr || node->kernel_info() == nullptr || candidate_kernel_info == 
nullptr) { @@ -126,7 +124,8 @@ void ChangeNodeInferInfo(const CNodePtr &cnode, const CNodePtr &cast, const size auto cast_shape = AnfAlgo::GetOutputInferShape(cast, 0); std::vector shapes; std::vector types; - for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(cnode); ++index) { + size_t output_num = AnfAlgo::GetOutputTensorNum(cnode); + for (size_t index = 0; index < output_num; ++index) { if (cast_index == index) { shapes.emplace_back(cast_shape); types.emplace_back(cast_dtype); @@ -175,7 +174,7 @@ AnfNodePtr MergeCastToNextOp(const FuncGraphPtr &graph, const CNodePtr &node, co << "ori kernel info" << ori_kernel_info->ToString() << "alternative kernel info" << (*alternative_kernel_info)->ToString(); AnfAlgo::SetSelectKernelBuildInfo(*alternative_kernel_info, next_cnode.get()); - if (node->inputs().size() < kCastInputNum) { + if (AnfAlgo::GetInputTensorNum(node) < kCastInputTensorNum) { MS_LOG(EXCEPTION) << "Op[" << node->DebugString() << "] has wrong input num:"; } return node->input(1); @@ -188,9 +187,7 @@ bool GetPriorOp(const AnfNodePtr &x_node, CNodePtr *prior_op, bool *single_outpu *prior_op = x_cnode; // when x_node is tuple_getitem if (AnfAlgo::GetCNodeName(x_node) == prim::kPrimTupleGetItem->name()) { - if (x_cnode->inputs().size() < kTupleGetitemInputNum) { - MS_LOG(EXCEPTION) << "tuple getitem node has wrong input num" << x_cnode->inputs().size(); - } + CheckCNodeInputSize(x_cnode, kTupleGetItemInputTensorNum); MS_EXCEPTION_IF_NULL(output_idx); AnfNodePtr input1 = x_cnode->input(1); MS_EXCEPTION_IF_NULL(input1); @@ -214,9 +211,7 @@ bool GetPriorOp(const AnfNodePtr &x_node, CNodePtr *prior_op, bool *single_outpu AnfNodePtr MergeCastToPriorOp(const FuncGraphPtr &graph, const CNodePtr &cur_node, const KernelQueryPtr kernel_query) { MS_EXCEPTION_IF_NULL(cur_node); MS_EXCEPTION_IF_NULL(kernel_query); - if (cur_node->inputs().size() < kCastInputNum) { - MS_LOG(EXCEPTION) << "op[Cast] has wrong input num:"; - } + CheckCNodeInputSize(cur_node, kCastInputTensorNum); AnfNodePtr x_node = cur_node->input(1); if (IsUsedByOthers(graph, x_node)) { return nullptr; diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/remove_internal_output.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/remove_internal_output.cc index f46dc9a433..fa677181ac 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/format_type/remove_internal_output.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/remove_internal_output.cc @@ -69,7 +69,7 @@ const AnfNodePtr RemoveInternalOutput::Process(const FuncGraphPtr &func_graph, c } auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); - CheckCNodeInputSize(cnode, kTransOpInputNum); + CheckCNodeInputSize(cnode, kTransOpInputTensorNum); auto input_node = cnode->input(1); if (!AnfAlgo::CheckPrimitiveType(input_node, prim::kPrimTupleGetItem)) { kernel_graph->ReplaceInternalOutput(node, input_node); diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_bert_fission.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_bert_fission.cc index f03e562c35..46dfd9dfd7 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_bert_fission.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_bert_fission.cc @@ -111,8 +111,8 @@ AnfNodePtr CreateBNTrainingUpdateV2(const FuncGraphPtr &func_graph, const AnfNod auto bn_abstract_tuple = dyn_cast(bn->abstract()); MS_EXCEPTION_IF_NULL(bn_abstract_tuple); - if (bn_abstract_tuple->elements().size() != kBatchNormOutputNum) { - 
MS_LOG(EXCEPTION) << "The abstract size of node bn must be " << kBatchNormOutputNum << ", but it is " + if (bn_abstract_tuple->elements().size() != kBnOutputNum) { + MS_LOG(EXCEPTION) << "The abstract size of node bn must be " << kBnOutputNum << ", but it is " << bn_abstract_tuple->elements().size() << " trace: " << trace::DumpSourceLines(bn); } std::vector abstract_list{bn_abstract_tuple->elements()[0], bn_abstract_tuple->elements()[3], diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_grad_split.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_grad_split.cc index b6d74c487d..ea729fd53e 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_grad_split.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/batch_norm_grad_split.cc @@ -33,10 +33,7 @@ void CreateOutputsOfUpdateGrad(const FuncGraphPtr &graph, const CNodePtr &bn_gra MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(bn_grad_node); const auto &bn_grad_inputs = bn_grad_node->inputs(); - if (bn_grad_inputs.size() < kBNGradInputNum) { - MS_LOG(EXCEPTION) << "BNGrad has wrong inputs size." - << " trace: " << trace::DumpSourceLines(bn_grad_node); - } + CheckCNodeInputSize(bn_grad_node, kBNGradInputTensorNum); std::vector bn_update_grad_inputs = { NewValueNode(std::make_shared(kBNTrainingUpdateGradOpName)), bn_grad_inputs[1], bn_grad_inputs[2], bn_grad_inputs[4], bn_grad_inputs[5]}; @@ -60,10 +57,7 @@ void CreateOutputsOfReduceGrad(const FuncGraphPtr &graph, const CNodePtr &bn_gra MS_EXCEPTION_IF_NULL(bn_grad_node); MS_EXCEPTION_IF_NULL(bn_reduce_grad_outputs); const auto &bn_grad_inputs = bn_grad_node->inputs(); - if (bn_grad_inputs.size() < kBNGradInputNum) { - MS_LOG(EXCEPTION) << "BNGrad has wrong inputs size" - << " trace: " << trace::DumpSourceLines(bn_grad_node); - } + CheckCNodeInputSize(bn_grad_node, kBNGradInputTensorNum); if (bn_update_grad_outputs.size() != kBNTrainingUpdateGradOutputNum) { MS_LOG(EXCEPTION) << "BNTrainingReduceGrad_outputs has wrong size" << " trace: " << trace::DumpSourceLines(bn_grad_node); diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_grad_split.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_grad_split.cc index 8a1d7d3a7c..0cd9acec47 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_grad_split.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_grad_split.cc @@ -33,10 +33,7 @@ void CreateOutputsOfUpdateGrad(const FuncGraphPtr &graph, const CNodePtr &bn_gra MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(bn_grad_node); auto bn_grad_inputs = bn_grad_node->inputs(); - if (bn_grad_inputs.size() != kBNGradInputNum) { - MS_LOG(EXCEPTION) << "BNGrad has wrong inputs size" - << " trace: " << trace::DumpSourceLines(bn_grad_node); - } + CheckCNodeInputSize(bn_grad_node, kBNGradInputTensorNum); std::vector bn_update_grad_inputs = { NewValueNode(std::make_shared(kBNTrainingUpdateGradOpName)), bn_grad_inputs[1], bn_grad_inputs[2], bn_grad_inputs[4], bn_grad_inputs[5]}; @@ -59,10 +56,7 @@ void CreateOutputsOfReduceGrad(const FuncGraphPtr &graph, const CNodePtr &bn_gra MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(bn_grad_node); auto bn_grad_inputs = bn_grad_node->inputs(); - if (bn_grad_inputs.size() != kBNGradInputNum) { - MS_LOG(EXCEPTION) << "BNGrad has wrong inputs size" - << " trace: " << trace::DumpSourceLines(bn_grad_node); - } + CheckCNodeInputSize(bn_grad_node, kBNGradInputTensorNum); if (bn_update_grad_outputs.size() != kBNTrainingUpdateGradOutputNum) { 
MS_LOG(EXCEPTION) << "bn_update_grad_outputs has wrong size"; } diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_split.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_split.cc index 2b36a94733..2c4d8065c8 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_split.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/bn_split.cc @@ -32,8 +32,8 @@ bool CreateOutputsOfBNTrainingReduce(const FuncGraphPtr &graph, const CNodePtr & std::vector *bn_training_reduce_outputs) { MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(bn_cnode); - if (bn_cnode->inputs().size() != kBnInputNum) { - MS_LOG(INFO) << "FusedbatchNorm's input size less than " << kBnInputNum << ". " << bn_cnode->DebugString(); + if (AnfAlgo::GetInputTensorNum(bn_cnode) != kBnInputTensorNum) { + MS_LOG(INFO) << "FusedbatchNorm's input size less than " << kBnInputTensorNum << ". " << bn_cnode->DebugString(); return false; } std::vector bn_training_reduce_inputs = { @@ -64,10 +64,7 @@ AnfNodePtr CreateOutputsOfBNTrainingUpdate(const FuncGraphPtr &graph, const CNod const std::vector &bn_training_reduce_outputs) { MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(bn_cnode); - if (bn_cnode->inputs().size() != kBnInputNum) { - MS_LOG(EXCEPTION) << "BN node has wrong input size" - << " trace: " << trace::DumpSourceLines(bn_cnode); - } + CheckCNodeInputSize(bn_cnode, kBnInputTensorNum); if (bn_training_reduce_outputs.size() != kBNTrainingReduceOutputNum) { MS_LOG(EXCEPTION) << "BN1 outputs has wrong input size" << " trace: " << trace::DumpSourceLines(bn_cnode); @@ -102,8 +99,8 @@ AnfNodePtr SplitBatchNormForTBE(const FuncGraphPtr &func_graph, const AnfNodePtr auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); - if (cnode->inputs().size() < kBnInputNum) { - MS_LOG(INFO) << "op[FusedBatchNorm] has less than " << kBnInputNum << " inputs."; + if (AnfAlgo::GetInputTensorNum(cnode) < kBnInputTensorNum) { + MS_LOG(INFO) << "op[" << cnode->DebugString() << "] has less input than " << kBnInputTensorNum << " inputs."; return nullptr; } // Create BNTrainingReduce node and get outputs of BNTrainingReduce diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/gather_v2_ds_fission.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/gather_v2_ds_fission.cc index c30520bb9d..8bea6ac3c4 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/gather_v2_ds_fission.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/gather_v2_ds_fission.cc @@ -123,8 +123,8 @@ CNodePtr CreateSlice(const FuncGraphPtr &graph, const CNodePtr &gather_v2, const bool CheckInputs(const CNodePtr &origin_node) { MS_EXCEPTION_IF_NULL(origin_node); - if (origin_node->size() != kGatherV2DynInputNum + 1) { - MS_LOG(DEBUG) << "GatherV2 in dynamic shape has wrong inputs num, not equal " << kGatherV2DynInputNum + if (AnfAlgo::GetInputTensorNum(origin_node) != kGatherV2DynInputTensorNum) { + MS_LOG(DEBUG) << "GatherV2 in dynamic shape has wrong inputs num, not equal " << kGatherV2DynInputTensorNum << ". 
CNode= " << origin_node->DebugString(); return false; } diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/lars_v2_fission.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/lars_v2_fission.cc index 5fba8692db..3f394d5e9b 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/lars_v2_fission.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/lars_v2_fission.cc @@ -28,11 +28,7 @@ void CreateOutputsOfSquareSumAll(const FuncGraphPtr &graph, const CNodePtr &lars std::vector *square_sum_all_outputs) { MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(lars_v2); - if (lars_v2->size() != kLarsV2InputNum) { - MS_LOG(EXCEPTION) << "Op lars_v2's input not equal " << kLarsV2InputNum - << " trace: " << trace::DumpSourceLines(lars_v2); - } - + CheckCNodeInputSize(lars_v2, kLarsV2InputTensorNum); std::vector inputs = {NewValueNode(std::make_shared(kSquareSumAllOpName)), lars_v2->input(1), lars_v2->input(2)}; auto square_sum_all = graph->NewCNode(inputs); @@ -55,10 +51,7 @@ CNodePtr CreateLarsV2Update(const FuncGraphPtr &graph, const CNodePtr &lars_v2, MS_LOG(EXCEPTION) << "square_sum_all_outputs' size not equal 2" << " trace: " << trace::DumpSourceLines(lars_v2); } - if (lars_v2->size() != kLarsV2InputNum) { - MS_LOG(EXCEPTION) << "Op lars_v2's input not equal " << kLarsV2InputNum - << " trace: " << trace::DumpSourceLines(lars_v2); - } + CheckCNodeInputSize(lars_v2, kLarsV2InputTensorNum); std::vector inputs = {NewValueNode(std::make_shared(kLarsV2UpdateOpName)), lars_v2->input(1), lars_v2->input(2), diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/layer_norm_grad_split.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/layer_norm_grad_split.cc index e4858d7a79..190e3a30ad 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/layer_norm_grad_split.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/layer_norm_grad_split.cc @@ -91,7 +91,7 @@ const AnfNodePtr LayerNormGradSplit::Process(const FuncGraphPtr &graph, const An return nullptr; } auto cnode = node->cast(); - if (cnode->inputs().size() != kLayerNormGradInputNum) { + if (AnfAlgo::GetInputTensorNum(cnode) != kLayerNormGradInputTensorNum) { return nullptr; } diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/reduce_min_fission.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/reduce_min_fission.cc index 9c3ca1d085..ef4e7c33ac 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/reduce_min_fission.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/reduce_min_fission.cc @@ -110,7 +110,7 @@ const AnfNodePtr ReduceMinFission::Process(const FuncGraphPtr &graph, const AnfN } auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); - CheckCNodeInputSize(cnode, 2); + CheckCNodeInputSize(cnode, 1); auto shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0); auto dtype = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, 0); auto prim = AnfAlgo::GetCNodePrimitive(cnode); diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/single_batch_norm_fission.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/single_batch_norm_fission.cc index d58644d74c..489083d06f 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/single_batch_norm_fission.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/single_batch_norm_fission.cc @@ -76,8 +76,8 @@ AnfNodePtr CreateBNTrainingUpdateV3(const FuncGraphPtr &func_graph, const AnfNod auto bn_abstract_tuple = dyn_cast(bn->abstract()); MS_EXCEPTION_IF_NULL(bn_abstract_tuple); - 
if (bn_abstract_tuple->elements().size() != kBatchNormOutputNum) { - MS_LOG(EXCEPTION) << "The abstract size of node bn must be " << kBatchNormOutputNum << ", but it is " + if (bn_abstract_tuple->elements().size() != kBnOutputNum) { + MS_LOG(EXCEPTION) << "The abstract size of node bn must be " << kBnOutputNum << ", but it is " << bn_abstract_tuple->elements().size() << " trace: " << trace::DumpSourceLines(bn); } bn_training_update_v3->set_abstract(bn->abstract()); diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/split_fission.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/split_fission.cc index dc8b497303..9cac96110c 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/split_fission.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/split_fission.cc @@ -34,10 +34,7 @@ CNodePtr CreateSplitVNode(const FuncGraphPtr &func_graph, const AnfNodePtr &inpu CNodePtr CreateBaseSplitVNode(const FuncGraphPtr &func_graph, const CNodePtr &origin_cnode) { MS_EXCEPTION_IF_NULL(origin_cnode); - if (origin_cnode->inputs().size() < kSplitInputNum) { - MS_LOG(EXCEPTION) << "The input number of split: " << origin_cnode->DebugString() << " should be " - << kSplitInputNum - 1 << " trace: " << trace::DumpSourceLines(origin_cnode); - } + CheckCNodeInputSize(origin_cnode, kSplitInputTensorNum); return CreateSplitVNode(func_graph, origin_cnode->input(1)); } diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/splitv_fission.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/splitv_fission.cc index fcc0267a54..ef7099adef 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/splitv_fission.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/splitv_fission.cc @@ -32,10 +32,7 @@ CNodePtr CreateSplitVNode(const FuncGraphPtr &func_graph, const AnfNodePtr &inpu CNodePtr CreateBaseSplitVNode(const FuncGraphPtr &func_graph, const CNodePtr &origin_cnode) { MS_EXCEPTION_IF_NULL(origin_cnode); - if (origin_cnode->inputs().size() < kSplitInputNum) { - MS_LOG(EXCEPTION) << "The input number of split: " << origin_cnode->DebugString() << " should be " - << kSplitInputNum - 1; - } + CheckCNodeInputSize(origin_cnode, kSplitInputTensorNum); return CreateSplitVNode(func_graph, origin_cnode->input(1)); } diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/topk_split.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/topk_split.cc index f5b968e1a0..3ed140656a 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/topk_split.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/topk_split.cc @@ -146,7 +146,7 @@ const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNod new_cnode->set_abstract(cnode->abstract()); new_cnode->set_scope(cnode->scope()); AnfAlgo::CopyNodeAttrs(cnode, new_cnode); - CheckCNodeInputSize(new_cnode, kTopkInputNum); + CheckCNodeInputSize(new_cnode, kTopkInputTensorNum); // Convert the tensor input to scalar and convert it to attr auto input_k = new_cnode->input(kTopkIndexK + 1); MS_EXCEPTION_IF_NULL(input_k); diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/transdata_split.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/transdata_split.cc index cce479e95a..10abf6d093 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/transdata_split.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/transdata_split.cc @@ -31,7 +31,7 @@ const AnfNodePtr TransDataSplit::Process(const FuncGraphPtr &func_graph, const A const EquivPtr &) const { 
MS_EXCEPTION_IF_NULL(func_graph); if (node != nullptr && node->isa() && AnfAlgo::GetCNodeName(node) == kTransDataOpName) { - CheckCNodeInputSize(node->cast(), kBackendTransDataInputNum); + CheckCNodeInputSize(node->cast(), kTransOpInputTensorNum); if (IsFormatInvaild(node)) { TraceGuard guard(std::make_shared(node->debug_info())); return DoSplit(func_graph, node); diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/unsorted_segment_sum_fission.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/unsorted_segment_sum_fission.cc index 9a129c7911..0896682271 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/unsorted_segment_sum_fission.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fission/unsorted_segment_sum_fission.cc @@ -77,8 +77,8 @@ CNodePtr CreateSlice(const FuncGraphPtr &graph, const CNodePtr &unsort_segment_s bool CheckInputs(const CNodePtr &origin_node) { MS_EXCEPTION_IF_NULL(origin_node); - if (origin_node->size() != kUnsortedSegmentSumInputNum + 1) { - MS_LOG(DEBUG) << "UnsortedSegmentSum has wrong inputs num, not equal " << kUnsortedSegmentSumInputNum + if (AnfAlgo::GetInputTensorNum(origin_node) != kUnsortedSegmentSumInputTensorNum) { + MS_LOG(DEBUG) << "UnsortedSegmentSum has wrong inputs num, not equal " << kUnsortedSegmentSumInputTensorNum << ". CNode= " << origin_node->DebugString(); return false; } diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnorm_to_bninfer.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnorm_to_bninfer.cc index 9e98d4ff18..d45ec995b1 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnorm_to_bninfer.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnorm_to_bninfer.cc @@ -62,8 +62,8 @@ bool CheckIndex(const AnfNodePtr &index_node) { bool CheckBatchNorm(const FuncGraphPtr &graph, const CNodePtr &batchnorm) { MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(batchnorm); - if (batchnorm->size() < kBatchNormInputNum + 1) { - MS_LOG(DEBUG) << "BatchNorm's input less than " << kBatchNormInputNum; + if (AnfAlgo::GetInputTensorNum(batchnorm) < kBnInputTensorNum) { + MS_LOG(DEBUG) << "BatchNorm's input less than " << kBnInputTensorNum; return false; } if (!AnfAlgo::HasNodeAttr(kAttrIsTraining, batchnorm)) { @@ -87,7 +87,7 @@ bool NeedFusion(const FuncGraphPtr &graph, const AnfNodePtr &node, CNodePtr *bat MS_EXCEPTION_IF_NULL(node); auto tuple_getitem = node->cast(); MS_EXCEPTION_IF_NULL(tuple_getitem); - CheckCNodeInputSize(tuple_getitem, kTupleGetItemInputSize); + CheckCNodeInputSize(tuple_getitem, kTupleGetItemInputTensorNum); AnfNodePtr index_node = tuple_getitem->input(kInputNodeOutputIndexInTupleGetItem); MS_EXCEPTION_IF_NULL(index_node); if (!CheckIndex(index_node)) { diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnormgrad_to_bninfergrad.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnormgrad_to_bninfergrad.cc index 8ad1f5ce36..cddc3b6a56 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnormgrad_to_bninfergrad.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/batchnormgrad_to_bninfergrad.cc @@ -61,8 +61,8 @@ bool CheckIndex(const AnfNodePtr &index_node) { bool CheckBatchNormGrad(const FuncGraphPtr &graph, const CNodePtr &batchnormgrad) { MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(batchnormgrad); - if (batchnormgrad->size() < kBatchNormInputNum + 1) { - MS_LOG(DEBUG) << "BatchNormGrad's input less than " << kBatchNormInputNum; + if (AnfAlgo::GetInputTensorNum(batchnormgrad) < 
kBNGradInputTensorNum) {
+    MS_LOG(DEBUG) << "BatchNormGrad's input tensor num is less than " << kBNGradInputTensorNum;
     return false;
   }
   if (!AnfAlgo::HasNodeAttr(kAttrIsTraining, batchnormgrad)) {
@@ -86,7 +86,7 @@ bool NeedFusion(const FuncGraphPtr &graph, const AnfNodePtr &node, CNodePtr *bat
   MS_EXCEPTION_IF_NULL(node);
   auto tuple_getitem = node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(tuple_getitem);
-  CheckCNodeInputSize(tuple_getitem, kTupleGetItemInputSize);
+  CheckCNodeInputSize(tuple_getitem, kTupleGetItemInputTensorNum);
   AnfNodePtr index_node = tuple_getitem->input(kInputNodeOutputIndexInTupleGetItem);
   MS_EXCEPTION_IF_NULL(index_node);
   if (!CheckIndex(index_node)) {
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/clip_by_value_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/clip_by_value_fusion.cc
index e1b0cb81e3..780fb0903c 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/clip_by_value_fusion.cc
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/clip_by_value_fusion.cc
@@ -79,7 +79,7 @@ const AnfNodePtr ClipByValueFusion::Process(const FuncGraphPtr &graph, const Anf
     return nullptr;
   }
   MS_EXCEPTION_IF_NULL(minimum);
-  if (minimum->inputs().size() != kMinimumInputNum) {
+  if (AnfAlgo::GetInputTensorNum(minimum) != kMinimumInputTensorNum) {
     return nullptr;
   }
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/derelu_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/derelu_fusion.cc
index fa111405c6..cd032a207e 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/derelu_fusion.cc
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/derelu_fusion.cc
@@ -30,9 +30,7 @@ const size_t kReluV2OutputNum = 2;
 CNodePtr GetRelu(const CNodePtr &relu_grad) {
   MS_EXCEPTION_IF_NULL(relu_grad);
-  if (relu_grad->size() != kReluGradInputNum) {
-    MS_LOG_EXCEPTION << "ReluGrad has wrong input size " << relu_grad->size();
-  }
+  CheckCNodeInputSize(relu_grad, kReluGradInputTensorNum);
   auto relu_anf = relu_grad->input(2);
   MS_EXCEPTION_IF_NULL(relu_anf);
   return relu_anf->cast<CNodePtr>();
 }
@@ -41,9 +39,7 @@ CNodePtr GetRelu(const CNodePtr &relu_grad) {
 CNodePtr CreateReluV2(const FuncGraphPtr &graph, const CNodePtr &relu) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(relu);
-  if (relu->size() != kReluInputNum) {
-    MS_LOG_EXCEPTION << "Relu has wrong input size " << relu->size();
-  }
+  CheckCNodeInputSize(relu, kReluInputTensorNum);
   auto prim = std::make_shared<Primitive>(kReluV2OpName);
   std::vector<AnfNodePtr> inputs = {NewValueNode(prim), relu->input(1)};
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/fused_batch_norm_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/fused_batch_norm_fusion.cc
index 4bb5d1159c..78e2d3e5fa 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/fused_batch_norm_fusion.cc
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/fused_batch_norm_fusion.cc
@@ -53,32 +53,9 @@ void GetBNOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, std::vect
   }
 }  // namespace
-const BaseRef FusedBatchNormFusion::DefinePattern() const {
-  std::shared_ptr<Var> Xs = std::make_shared<SeqVar>();
-  VarPtr index0 = std::make_shared<CondVar>(IsC);
-  VarPtr index1 = std::make_shared<CondVar>(IsC);
-  VarPtr index2 = std::make_shared<CondVar>(IsC);
-  VectorRef batch_norm = VectorRef({batch_norm_var_, data_input0_var_, data_input1_var_, data_input2_var_, Xs});
-  VectorRef tuple_getitem0 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index0});
-  VectorRef tuple_getitem1 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index1});
-  VectorRef tuple_getitem2 = 
VectorRef({prim::kPrimTupleGetItem, batch_norm, index2}); - VectorRef sub0 = VectorRef({prim::kPrimSub, variable_input0_var_, tuple_getitem1}); - VectorRef sub1 = VectorRef({prim::kPrimSub, variable_input1_var_, tuple_getitem2}); - VectorRef mul0 = VectorRef({prim::kPrimMul, sub0, constant_input0_var_}); - VectorRef mul1 = VectorRef({prim::kPrimMul, sub1, constant_input1_var_}); - VectorRef assign_sub0 = VectorRef({prim::kPrimAssignSub, variable_input0_var_, mul0}); - VectorRef assign_sub1 = VectorRef({prim::kPrimAssignSub, variable_input1_var_, mul1}); - VectorRef depend0 = VectorRef({prim::kPrimDepend, tuple_getitem0, assign_sub0}); - return VectorRef({prim::kPrimDepend, depend0, assign_sub1}); -} - ValuePtr FusedBatchNormFusion::GetFactor(const EquivPtr &equiv) const { MS_EXCEPTION_IF_NULL(equiv); - auto iter_constant_input0 = (*equiv).find(constant_input0_var_); - if (iter_constant_input0 == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the constant_input0 var after matched."; - } - auto constant_input = utils::cast(iter_constant_input0->second); + auto constant_input = GetAnfNodeByVar(equiv, constant_input0_var_); MS_EXCEPTION_IF_NULL(constant_input); if (!constant_input->isa()) { return nullptr; @@ -113,31 +90,15 @@ AnfNodePtr FusedBatchNormFusion::CreateBNTrainingReduce(const FuncGraphPtr &func MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(equiv); // Set input to create node - auto iter_data_input0 = (*equiv).find(data_input0_var_); - if (iter_data_input0 == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the data_input0 var after matched." - << " trace: " << trace::DumpSourceLines(node); - } std::vector bn_training_reduce_inputs = { - NewValueNode(std::make_shared(kBNTrainingReduceOpName)), - utils::cast(iter_data_input0->second)}; + NewValueNode(std::make_shared(kBNTrainingReduceOpName)), GetAnfNodeByVar(equiv, data_input0_var_)}; auto bn_training_reduce = func_graph->NewCNode(bn_training_reduce_inputs); MS_EXCEPTION_IF_NULL(bn_training_reduce); bn_training_reduce->set_scope(node->scope()); // Set abstract - auto iter_data_input1 = (*equiv).find(data_input1_var_); - if (iter_data_input1 == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the data_input1 var after matched." - << " trace: " << trace::DumpSourceLines(node); - } - auto data_input1 = utils::cast(iter_data_input1->second); + auto data_input1 = GetAnfNodeByVar(equiv, data_input1_var_); MS_EXCEPTION_IF_NULL(data_input1); - auto iter_data_input2 = (*equiv).find(data_input2_var_); - if (iter_data_input2 == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the data_input2 var after matched." 
- << " trace: " << trace::DumpSourceLines(node); - } - auto data_input2 = utils::cast(iter_data_input2->second); + auto data_input2 = GetAnfNodeByVar(equiv, data_input2_var_); MS_EXCEPTION_IF_NULL(data_input2); AbstractBasePtrList abstract_list{data_input1->abstract(), data_input2->abstract()}; auto abstract_tuple = std::make_shared(abstract_list); @@ -150,39 +111,15 @@ void FusedBatchNormFusion::GetBNTrainingUpdateInputs(const EquivPtr &equiv, std::vector *bn_training_update_inputs) const { MS_EXCEPTION_IF_NULL(equiv); MS_EXCEPTION_IF_NULL(bn_training_update_inputs); - auto iter_data_input0 = (*equiv).find(data_input0_var_); - if (iter_data_input0 == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the data_input0 var after matched."; - } - auto iter_data_input1 = (*equiv).find(data_input1_var_); - if (iter_data_input1 == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the data_input1 var after matched."; - } - auto iter_data_input2 = (*equiv).find(data_input2_var_); - if (iter_data_input2 == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the data_input2 var after matched."; - } - auto iter_variable_input0 = (*equiv).find(variable_input0_var_); - if (iter_variable_input0 == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the variable_input0 var after matched."; - } - auto iter_variable_input1 = (*equiv).find(variable_input1_var_); - if (iter_variable_input1 == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the variable_input1 var after matched."; - } - if (bn_training_reduce_outputs.size() != kBNTrainingReduceOutputNum) { - MS_LOG(EXCEPTION) << "The output size of node bn_training_reduce must be " << kBNTrainingReduceOutputNum - << ", but it is " << bn_training_reduce_outputs.size(); - } *bn_training_update_inputs = { NewValueNode(std::make_shared(kBNTrainingUpdateOpName)), - utils::cast(iter_data_input0->second), + utils::cast(GetAnfNodeByVar(equiv, data_input0_var_)), bn_training_reduce_outputs[0], bn_training_reduce_outputs[1], - utils::cast(iter_data_input1->second), - utils::cast(iter_data_input2->second), - utils::cast(iter_variable_input0->second), - utils::cast(iter_variable_input1->second), + GetAnfNodeByVar(equiv, data_input1_var_), + GetAnfNodeByVar(equiv, data_input2_var_), + GetAnfNodeByVar(equiv, variable_input0_var_), + GetAnfNodeByVar(equiv, variable_input1_var_), }; } @@ -197,19 +134,9 @@ void FusedBatchNormFusion::GetBNTrainingUpdateAbstractList(const EquivPtr &equiv MS_LOG(EXCEPTION) << "The abstract size of node bn must not be less than " << kBnOutputNum << ", but it is " << bn_abstract_tuple->elements().size() << " trace: " << trace::DumpSourceLines(bn); } - auto iter_variable_input0 = (*equiv).find(variable_input0_var_); - if (iter_variable_input0 == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the variable_input0 var after matched." - << " trace: " << trace::DumpSourceLines(bn); - } - auto variable_input0 = utils::cast(iter_variable_input0->second); + auto variable_input0 = GetAnfNodeByVar(equiv, variable_input0_var_); + auto variable_input1 = GetAnfNodeByVar(equiv, variable_input1_var_); MS_EXCEPTION_IF_NULL(variable_input0); - auto iter_variable_input1 = (*equiv).find(variable_input1_var_); - if (iter_variable_input1 == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the variable_input1 var after matched." 
- << " trace: " << trace::DumpSourceLines(bn); - } - auto variable_input1 = utils::cast(iter_variable_input1->second); MS_EXCEPTION_IF_NULL(variable_input1); *abstract_list = {bn_abstract_tuple->elements()[0], variable_input0->abstract(), variable_input1->abstract(), bn_abstract_tuple->elements()[1], bn_abstract_tuple->elements()[2]}; @@ -227,13 +154,7 @@ AnfNodePtr FusedBatchNormFusion::CreateBNTrainingUpdate( auto bn_training_update = func_graph->NewCNode(bn_training_update_inputs); MS_EXCEPTION_IF_NULL(bn_training_update); // Set abstract - auto iter_batch_norm = (*equiv).find(batch_norm_var_); - if (iter_batch_norm == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the batch_norm var after matched." - << " trace: " << trace::DumpSourceLines(node); - } - AnfNodePtr bn = utils::cast(iter_batch_norm->second); - MS_EXCEPTION_IF_NULL(bn); + AnfNodePtr bn = GetAnfNodeByVar(equiv, batch_norm_var_); AbstractBasePtrList abstract_list; GetBNTrainingUpdateAbstractList(equiv, bn, &abstract_list); auto abstract_tuple = std::make_shared(abstract_list); @@ -249,6 +170,23 @@ AnfNodePtr FusedBatchNormFusion::CreateBNTrainingUpdate( return bn_training_update; } +void FusedBatchNormFusion::EliminateMonadNodes(const FuncGraphPtr &func_graph, const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(equiv); + auto manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + auto assign_sub1 = GetAnfNodeByVar(equiv, assign_sub1_var_); + MS_EXCEPTION_IF_NULL(assign_sub1); + for (const auto &node_index : manager->node_users()[assign_sub1]) { + const AnfNodePtr &output = node_index.first; + MS_EXCEPTION_IF_NULL(output); + if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimUpdateState)) { + (void)manager->Replace(output, GetAnfNodeByVar(equiv, monad0_var_)); + break; + } + } +} + const AnfNodePtr FusedBatchNormFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv) const { MS_EXCEPTION_IF_NULL(func_graph); @@ -271,14 +209,8 @@ const AnfNodePtr FusedBatchNormFusion::Process(const FuncGraphPtr &func_graph, c << bn_training_update_outputs.size() << " trace: " << trace::DumpSourceLines(node); } // Replace old bn outputs with new outputs - auto iter_batch_norm = (*equiv).find(batch_norm_var_); - if (iter_batch_norm == (*equiv).end()) { - MS_LOG(EXCEPTION) << "The equiv map is expected to contains the batch_norm var after matched." 
- << " trace: " << trace::DumpSourceLines(node); - } - AnfNodePtr bn = utils::cast(iter_batch_norm->second); std::vector bn_outputs; - GetBNOutput(func_graph, bn, &bn_outputs); + GetBNOutput(func_graph, GetAnfNodeByVar(equiv, batch_norm_var_), &bn_outputs); auto manager = func_graph->manager(); MS_EXCEPTION_IF_NULL(manager); for (const auto &output : bn_outputs) { @@ -297,7 +229,28 @@ const AnfNodePtr FusedBatchNormFusion::Process(const FuncGraphPtr &func_graph, c (void)manager->Replace(output, bn_training_update_outputs[index]); } } - return bn_training_update_outputs[0]; + (void)manager->Replace(node, bn_training_update_outputs[0]); + EliminateMonadNodes(func_graph, equiv); + return nullptr; +} + +const BaseRef FusedBatchNormFusion::DefinePattern() const { + std::shared_ptr Xs = std::make_shared(); + VarPtr index0 = std::make_shared(IsC); + VarPtr index1 = std::make_shared(IsC); + VarPtr index2 = std::make_shared(IsC); + VectorRef batch_norm = VectorRef({batch_norm_var_, data_input0_var_, data_input1_var_, data_input2_var_, Xs}); + VectorRef tuple_getitem0 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index0}); + VectorRef tuple_getitem1 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index1}); + VectorRef tuple_getitem2 = VectorRef({prim::kPrimTupleGetItem, batch_norm, index2}); + VectorRef sub0 = VectorRef({prim::kPrimSub, variable_input0_var_, tuple_getitem1}); + VectorRef sub1 = VectorRef({prim::kPrimSub, variable_input1_var_, tuple_getitem2}); + VectorRef mul0 = VectorRef({prim::kPrimMul, sub0, constant_input0_var_}); + VectorRef mul1 = VectorRef({prim::kPrimMul, sub1, constant_input1_var_}); + VectorRef assign_sub0 = VectorRef({assign_sub0_var_, variable_input0_var_, mul0, monad0_var_}); + VectorRef assign_sub1 = VectorRef({assign_sub1_var_, variable_input1_var_, mul1, monad1_var_}); + VectorRef depend0 = VectorRef({prim::kPrimDepend, tuple_getitem0, assign_sub0}); + return VectorRef({prim::kPrimDepend, depend0, assign_sub1}); } const BaseRef FusedBatchNormMixPrecisionFusion0::DefinePattern() const { @@ -317,8 +270,8 @@ const BaseRef FusedBatchNormMixPrecisionFusion0::DefinePattern() const { VectorRef mul1 = VectorRef({prim::kPrimMul, sub1, constant_input1_var_}); VectorRef cast2 = VectorRef({prim::kPrimCast, mul0}); VectorRef cast3 = VectorRef({prim::kPrimCast, mul1}); - VectorRef assign_sub0 = VectorRef({prim::kPrimAssignSub, variable_input0_var_, cast2}); - VectorRef assign_sub1 = VectorRef({prim::kPrimAssignSub, variable_input1_var_, cast3}); + VectorRef assign_sub0 = VectorRef({assign_sub0_var_, variable_input0_var_, cast2, monad0_var_}); + VectorRef assign_sub1 = VectorRef({assign_sub1_var_, variable_input1_var_, cast3, monad1_var_}); VectorRef depend0 = VectorRef({prim::kPrimDepend, tuple_getitem0, assign_sub0}); return VectorRef({prim::kPrimDepend, depend0, assign_sub1}); } @@ -340,8 +293,8 @@ const BaseRef FusedBatchNormMixPrecisionFusion1::DefinePattern() const { VectorRef cast1 = VectorRef({prim::kPrimCast, sub1}); VectorRef mul0 = VectorRef({prim::kPrimMul, cast0, constant_input0_var_}); VectorRef mul1 = VectorRef({prim::kPrimMul, cast1, constant_input1_var_}); - VectorRef assign_sub0 = VectorRef({prim::kPrimAssignSub, variable_input0_var_, mul0}); - VectorRef assign_sub1 = VectorRef({prim::kPrimAssignSub, variable_input1_var_, mul1}); + VectorRef assign_sub0 = VectorRef({assign_sub0_var_, variable_input0_var_, mul0, monad0_var_}); + VectorRef assign_sub1 = VectorRef({assign_sub1_var_, variable_input1_var_, mul1, monad1_var_}); VectorRef depend0 = 
VectorRef({prim::kPrimDepend, tuple_getitem0, assign_sub0}); return VectorRef({prim::kPrimDepend, depend0, assign_sub1}); } diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/fused_batch_norm_fusion.h b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/fused_batch_norm_fusion.h index 04cbed35f2..beafc367d2 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/fused_batch_norm_fusion.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/fused_batch_norm_fusion.h @@ -27,15 +27,20 @@ namespace opt { class FusedBatchNormFusion : public PatternProcessPass { public: explicit FusedBatchNormFusion(const std::string &name = "fused_batch_norm_fusion", bool multigraph = true) - : PatternProcessPass(name, multigraph), - data_input0_var_(std::make_shared()), - data_input1_var_(std::make_shared()), - data_input2_var_(std::make_shared()), - variable_input0_var_(std::make_shared()), - variable_input1_var_(std::make_shared()), - constant_input0_var_(std::make_shared()), - constant_input1_var_(std::make_shared()), - batch_norm_var_(std::make_shared(std::make_shared(prim::kPrimBatchNorm->name()))) {} + : PatternProcessPass(name, multigraph) { + data_input0_var_ = std::make_shared(); + data_input1_var_ = std::make_shared(); + data_input2_var_ = std::make_shared(); + variable_input0_var_ = std::make_shared(); + variable_input1_var_ = std::make_shared(); + constant_input0_var_ = std::make_shared(); + constant_input1_var_ = std::make_shared(); + batch_norm_var_ = std::make_shared(std::make_shared(prim::kPrimBatchNorm->name())); + assign_sub0_var_ = std::make_shared(std::make_shared(prim::kPrimAssignSub->name())); + assign_sub1_var_ = std::make_shared(std::make_shared(prim::kPrimAssignSub->name())); + monad0_var_ = std::make_shared(); + monad1_var_ = std::make_shared(); + } ~FusedBatchNormFusion() override = default; const BaseRef DefinePattern() const override; const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; @@ -50,6 +55,7 @@ class FusedBatchNormFusion : public PatternProcessPass { AnfNodePtr CreateBNTrainingUpdate(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv, const std::vector &bn_training_reduce_outputs) const; ValuePtr GetFactor(const EquivPtr &equiv) const; + void EliminateMonadNodes(const FuncGraphPtr &func_graph, const EquivPtr &equiv) const; VarPtr data_input0_var_; VarPtr data_input1_var_; @@ -59,6 +65,10 @@ class FusedBatchNormFusion : public PatternProcessPass { VarPtr constant_input0_var_; VarPtr constant_input1_var_; VarPtr batch_norm_var_; + VarPtr assign_sub0_var_; + VarPtr assign_sub1_var_; + VarPtr monad0_var_; + VarPtr monad1_var_; }; class FusedBatchNormMixPrecisionFusion0 : public FusedBatchNormFusion { diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule.cc index 313e88af3f..5c31af153b 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/lamb_next_mv_with_decay_v1_rule.cc @@ -30,33 +30,21 @@ std::tuple GetSharedNodes(const MS_EXCEPTION_IF_NULL(node); auto add3 = node->cast(); MS_EXCEPTION_IF_NULL(add3); - if (add3->inputs().size() < kAddInputNum) { - MS_LOG(EXCEPTION) << "The input size of Add3 is less than " << kAddInputNum - << " trace: " << trace::DumpSourceLines(node); - } + CheckCNodeInputSize(add3, kAddInputTensorNum); auto real_div2_anf = 
add3->input(1); MS_EXCEPTION_IF_NULL(real_div2_anf); auto real_div2 = real_div2_anf->cast(); MS_EXCEPTION_IF_NULL(real_div2); - if (real_div2->inputs().size() < kRealDivInputNum) { - MS_LOG(EXCEPTION) << "The input size of RealDiv2 is less than " << kRealDivInputNum - << " trace: " << trace::DumpSourceLines(node); - } + CheckCNodeInputSize(real_div2, kRealDivInputTensorNum); auto sqrt0_anf = real_div2->input(2); MS_EXCEPTION_IF_NULL(sqrt0_anf); auto sqrt0 = sqrt0_anf->cast(); MS_EXCEPTION_IF_NULL(sqrt0); - if (sqrt0->inputs().size() < kRsqrtInputNum) { - MS_LOG(EXCEPTION) << "The input size of Sqrt0 is less than " << kSqrtInputNum - << " trace: " << trace::DumpSourceLines(node); - } + CheckCNodeInputSize(sqrt0, kSqrtInputTensorNum); auto add2_anf = sqrt0->input(1); MS_EXCEPTION_IF_NULL(add2_anf); auto add2 = add2_anf->cast(); - if (add2->inputs().size() < kAddInputNum) { - MS_LOG(EXCEPTION) << "The input size of Add2 is less than " << kAddInputNum - << " trace: " << trace::DumpSourceLines(node); - } + CheckCNodeInputSize(add2, kAddInputTensorNum); return std::make_tuple(add3->input(2), real_div2->input(1), add2->input(1), add2->input(2)); } @@ -66,7 +54,7 @@ bool MatchAdd5Pattern(const AnfNodePtr &node, const AnfNodePtr &mul4, const AnfN return false; } auto add5 = node->cast(); - if (AnfAlgo::GetCNodeName(add5) != prim::kPrimAdd->name() || add5->inputs().size() != kAddInputNum) { + if (AnfAlgo::GetCNodeName(add5) != prim::kPrimAdd->name() || AnfAlgo::GetInputTensorNum(add5) != kAddInputTensorNum) { return false; } auto real_div4_anf = add5->input(1); @@ -74,7 +62,8 @@ bool MatchAdd5Pattern(const AnfNodePtr &node, const AnfNodePtr &mul4, const AnfN return false; } auto real_div4 = real_div4_anf->cast(); - if (AnfAlgo::GetCNodeName(real_div4) != kRealDivOpName || real_div4->inputs().size() != kRealDivInputNum) { + if (AnfAlgo::GetCNodeName(real_div4) != kRealDivOpName || + AnfAlgo::GetInputTensorNum(real_div4) != kRealDivInputTensorNum) { return false; } auto add4_anf = real_div4->input(2); @@ -82,7 +71,7 @@ bool MatchAdd5Pattern(const AnfNodePtr &node, const AnfNodePtr &mul4, const AnfN return false; } auto add4 = add4_anf->cast(); - if (AnfAlgo::GetCNodeName(add4) != prim::kPrimAdd->name() || add4->inputs().size() != kAddInputNum) { + if (AnfAlgo::GetCNodeName(add4) != prim::kPrimAdd->name() || AnfAlgo::GetInputTensorNum(add4) != kAddInputTensorNum) { return false; } auto sqrt1_anf = add4->input(1); @@ -90,7 +79,7 @@ bool MatchAdd5Pattern(const AnfNodePtr &node, const AnfNodePtr &mul4, const AnfN return false; } auto sqrt1 = sqrt1_anf->cast(); - if (AnfAlgo::GetCNodeName(sqrt1) != kSqrtOpName || sqrt1->inputs().size() != kSqrtInputNum) { + if (AnfAlgo::GetCNodeName(sqrt1) != kSqrtOpName || AnfAlgo::GetInputTensorNum(sqrt1) != kSqrtInputTensorNum) { return false; } return add5->input(2) == mul4 && real_div4->input(1) == real_div0 && sqrt1->input(1) == real_div1 && @@ -104,14 +93,8 @@ std::tuple GetAdd0Add1Nodes(const AnfNodePtr &real_div0_ auto real_div1 = real_div1_anf->cast(); MS_EXCEPTION_IF_NULL(real_div0); MS_EXCEPTION_IF_NULL(real_div1); - if (real_div0->inputs().size() != kRealDivInputNum) { - MS_LOG(EXCEPTION) << "RealDiv0 has wrong input size" - << " trace: " << trace::DumpSourceLines(real_div0_anf); - } - if (real_div1->inputs().size() != kRealDivInputNum) { - MS_LOG(EXCEPTION) << "RealDiv1 has wrong input size" - << " trace: " << trace::DumpSourceLines(real_div1_anf); - } + CheckCNodeInputSize(real_div0, kRealDivInputTensorNum); + CheckCNodeInputSize(real_div1, 
kRealDivInputTensorNum); return std::make_tuple(real_div0->input(1), real_div1->input(1)); } } // namespace diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion.cc index 8a23ac2f84..5822f8aa31 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/layer_norm_beta_gamma_backprop_fusion.cc @@ -77,9 +77,9 @@ bool CheckLayernormBetaGammaBackprop(const FuncGraphPtr &func_graph, const CNode MS_LOG(INFO) << "The node " << cnode->DebugString() << " has no " << kAttrShapeGamma << " attr"; return false; } - if (cnode->inputs().size() != kLayerNormBetaGammaBackpropInputNum) { + if (AnfAlgo::GetInputTensorNum(cnode) != kLayerNormBetaGammaBackpropInputTensorNum) { MS_LOG(INFO) << "The node " << cnode->DebugString() << " inputs num is not equal to " - << kLayerNormBetaGammaBackpropInputNum; + << kLayerNormBetaGammaBackpropInputTensorNum; return false; } if (AnfAlgo::GetOutputTensorNum(cnode) != kLayerNormBetaGammaBackpropOutputNum) { @@ -87,7 +87,8 @@ bool CheckLayernormBetaGammaBackprop(const FuncGraphPtr &func_graph, const CNode << kLayerNormBetaGammaBackpropOutputNum; return false; } - for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(cnode); ++i) { + size_t input_num = AnfAlgo::GetInputTensorNum(cnode); + for (size_t i = 0; i < input_num; ++i) { if (AnfAlgo::GetInputDeviceDataType(cnode, i) != kNumberTypeFloat16) { MS_LOG(INFO) << "The data type of node " << cnode->DebugString() << " input " << i << " is not float16"; return false; @@ -148,15 +149,9 @@ const AnfNodePtr LayerNormBetaGammaBackpropFusion::Process(const FuncGraphPtr &f // The cast_nodes size has been checked above. MS_EXCEPTION_IF_NULL(cast_nodes[0]); MS_EXCEPTION_IF_NULL(cast_nodes[1]); - if (cast_nodes[0]->inputs().size() != kCastInputNum) { - MS_LOG(EXCEPTION) << "The cast0 " << cast_nodes[0]->DebugString() << " input size should be " << kCastInputNum - << " trace: " << trace::DumpSourceLines(node); - } + CheckCNodeInputSize(cast_nodes[0], kCastInputTensorNum); + CheckCNodeInputSize(cast_nodes[1], kCastInputTensorNum); (void)manager->Replace(cast_nodes[0], cast_nodes[0]->input(1)); - if (cast_nodes[1]->inputs().size() != kCastInputNum) { - MS_LOG(EXCEPTION) << "The cast1 " << cast_nodes[1]->DebugString() << " input size should be " << kCastInputNum - << " trace: " << trace::DumpSourceLines(node); - } (void)manager->Replace(cast_nodes[1], cast_nodes[1]->input(1)); return nullptr; } diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/matmul_biasadd_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/matmul_biasadd_fusion.cc index dd729c5bc2..db3e1f9822 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/matmul_biasadd_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/matmul_biasadd_fusion.cc @@ -31,6 +31,20 @@ const AnfNodePtr MatmulBiasaddFusion::Process(const FuncGraphPtr &graph, const A const EquivPtr &equiv) const { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(graph); + + auto matmul = GetAnfNodeByVar(equiv, matmul_var_); + if (matmul == nullptr || !matmul->isa()) { + MS_LOG(EXCEPTION) << "Get CNode MatMul failed!" 
+ << " trace: " << trace::DumpSourceLines(node); + } + + // If there is a side-effect operator in the fusion, do not merge + MonadState state_matmul = GetMonadState(matmul); + MonadState state_node = GetMonadState(node, matmul); + if (!IsStateEquivalent(state_matmul, state_node)) { + return node; + } + std::vector inputs; inputs.emplace_back(NewValueNode(std::make_shared(prim::kPrimMatMul->name()))); inputs.emplace_back(GetAnfNodeByVar(equiv, x0_)); @@ -41,11 +55,6 @@ const AnfNodePtr MatmulBiasaddFusion::Process(const FuncGraphPtr &graph, const A new_node->set_scope(node->scope()); new_node->set_abstract(node->abstract()); - auto matmul = GetAnfNodeByVar(equiv, matmul_var_); - if (matmul == nullptr || !matmul->isa()) { - MS_LOG(EXCEPTION) << "Get CNode MatMul failed!" - << " trace: " << trace::DumpSourceLines(node); - } AnfAlgo::CopyNodeAttrs(matmul, new_node); return new_node; } diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/momentum_lossscale_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/momentum_lossscale_fusion.cc index 9c3a59ed7c..4905b946d3 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/momentum_lossscale_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/momentum_lossscale_fusion.cc @@ -43,7 +43,9 @@ const BaseRef MomentumLossscaleFusion::DefinePattern() const { VarPtr X1 = std::make_shared(); VarPtr X2 = std::make_shared(); VarPtr X4 = std::make_shared(); - return VectorRef({prim::kPrimApplyMomentum, X0, X1, X2, VectorRef({prim::kPrimMul, Xs}), X4}); + // UpdateState node + VarPtr X5 = std::make_shared(); + return VectorRef({prim::kPrimApplyMomentum, X0, X1, X2, VectorRef({prim::kPrimMul, Xs}), X4, X5}); } const AnfNodePtr MomentumLossscaleFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, @@ -52,14 +54,15 @@ const AnfNodePtr MomentumLossscaleFusion::Process(const FuncGraphPtr &func_graph MS_EXCEPTION_IF_NULL(node); auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); - CheckCNodeInputSize(cnode, kApplyMomentumInputNum); + CheckCNodeInputSize(cnode, kApplyMomentumInputTensorNum); AnfNodePtr mul = cnode->input(4); MS_EXCEPTION_IF_NULL(mul); auto mul_cnode = mul->cast(); MS_EXCEPTION_IF_NULL(mul_cnode); - CheckCNodeInputSize(mul_cnode, kMulInputNum); + CheckCNodeInputSize(mul_cnode, kMulInputTensorNum); size_t value_node_index = 0; - for (size_t i = 1; i < kMulInputNum; ++i) { + // All real inputs include 1prim + x*TensorInput + for (size_t i = 1; i < kMulInputTensorNum + 1; ++i) { if (CheckValueNodeInputOfMul(mul_cnode->input(i))) { value_node_index = i; break; @@ -70,12 +73,16 @@ const AnfNodePtr MomentumLossscaleFusion::Process(const FuncGraphPtr &func_graph return nullptr; } auto new_prim = std::make_shared(kFusedMulApplyMomentumOpName); + auto depend_prim = NewValueNode(prim::kPrimDepend); + auto depend = func_graph->NewCNode({depend_prim, cnode->input(5), cnode->input(6)}); // depend on monad + depend->set_abstract(cnode->input(5)->abstract()); + depend->set_scope(cnode->input(5)->scope()); std::vector new_node_inputs{NewValueNode(new_prim), cnode->input(1), cnode->input(2), cnode->input(3), - mul_cnode->input(kMulInputNum - value_node_index), - cnode->input(5), + mul_cnode->input(kMulInputTensorNum + 1 - value_node_index), + depend, mul_cnode->input(value_node_index)}; auto new_node = func_graph->NewCNode(new_node_inputs); MS_EXCEPTION_IF_NULL(new_node); diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/mul_add_fusion.cc 
b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/mul_add_fusion.cc index e6e146f4a0..85599b6975 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/mul_add_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/mul_add_fusion.cc @@ -67,7 +67,7 @@ const AnfNodePtr MulAddFusion::Process(const FuncGraphPtr &graph, const AnfNodeP return nullptr; } auto add = node->cast(); - if (add == nullptr || add->inputs().size() != kAddInputNum) { + if (add == nullptr || AnfAlgo::GetInputTensorNum(add) != kAddInputTensorNum) { return nullptr; } CNodePtr mul = nullptr; diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/mul_addn_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/mul_addn_fusion.cc index a0f803e6a3..ea0c73ac90 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/mul_addn_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/mul_addn_fusion.cc @@ -31,7 +31,7 @@ CNodePtr CreateFusionNode(const FuncGraphPtr &graph, const CNodePtr &mul, const MS_EXCEPTION_IF_NULL(addn); auto prim = std::make_shared(kFusedMulAddNOpName); std::vector inputs = {NewValueNode(prim)}; - inputs.push_back(mul->input(kMulInputNum - lossscale_input_index)); + inputs.push_back(mul->input(kMulInputTensorNum + 1 - lossscale_input_index)); inputs.push_back(addn->input(2)); // scalar input should be 3rd input inputs.push_back(mul->input(lossscale_input_index)); @@ -60,7 +60,7 @@ const AnfNodePtr MulAddNFusion::Process(const FuncGraphPtr &graph, const AnfNode } auto addn = node->cast(); - if (addn == nullptr || addn->inputs().size() != kAddNInputNum) { + if (addn == nullptr) { return nullptr; } auto mul_anf = addn->input(1); @@ -68,7 +68,7 @@ const AnfNodePtr MulAddNFusion::Process(const FuncGraphPtr &graph, const AnfNode return nullptr; } auto mul = mul_anf->cast(); - if (mul == nullptr || mul->inputs().size() != kMulInputNum) { + if (mul == nullptr || AnfAlgo::GetInputTensorNum(mul) != kMulInputTensorNum) { return nullptr; } if (IsUsedByOthers(graph, mul)) { diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/parameter_and_transop_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/parameter_and_transop_fusion.cc index 501327a261..4b467bc871 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/parameter_and_transop_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/parameter_and_transop_fusion.cc @@ -98,7 +98,8 @@ bool ParameterTransOpFusion::Run(const FuncGraphPtr &func_graph) { MS_LOG(DEBUG) << "Skip trans op"; continue; } - for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); input_index++) { + size_t input_num = AnfAlgo::GetInputTensorNum(cnode); + for (size_t input_index = 0; input_index < input_num; input_index++) { std::vector trans_road; bool first_flag = true; auto final_node = ParamTransRoad(func_graph, AnfAlgo::GetInputNode(cnode, input_index), first_flag, &trans_road); diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/refresh_parameter_format.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/refresh_parameter_format.cc index b6b179b625..1170f77c19 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/refresh_parameter_format.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/refresh_parameter_format.cc @@ -26,8 +26,10 @@ void DoRefresh(const CNodePtr &cnode) { if (cnode == nullptr) { MS_LOG(EXCEPTION) << "node is nullptr"; } - for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); input_index++) { - auto 
input_kernel_node = AnfAlgo::GetInputNode(cnode, input_index);
+ size_t input_num = AnfAlgo::GetInputTensorNum(cnode);
+ for (size_t input_index = 0; input_index < input_num; input_index++) {
+ auto input_kernel_node = AnfAlgo::VisitKernel(AnfAlgo::GetInputNode(cnode, input_index), 0).first;
+ MS_EXCEPTION_IF_NULL(input_kernel_node);
 if (input_kernel_node->isa()) {
 std::shared_ptr builder = std::make_shared();
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/remove_reshape_pair.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/remove_reshape_pair.cc
index 706c8e4fc0..ea894025fa 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/remove_reshape_pair.cc
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/remove_reshape_pair.cc
@@ -34,13 +34,14 @@ const AnfNodePtr RemoveReshapePair::Process(const FuncGraphPtr &func_graph, cons
 const EquivPtr &equiv) const {
 MS_EXCEPTION_IF_NULL(func_graph);
 MS_EXCEPTION_IF_NULL(equiv);
- auto out_reshape = CheckAnfNodeIfCNodeAndInputSize(node, kBackendReshapeInputNum);
+ auto out_reshape = CheckAnfNodeIfCNodeAndInputSize(node, kBackendReshapeInputTensorNum);
 MS_EXCEPTION_IF_NULL(out_reshape);
 // If the reshape operator is used by more than one other operator, it cannot be deleted directly
 if (IsUsedByOthers(func_graph, out_reshape)) {
 return nullptr;
 }
- auto in_reshape = CheckAnfNodeIfCNodeAndInputSize(AnfAlgo::GetInputNode(out_reshape, 0), kBackendReshapeInputNum);
+ auto in_reshape =
+ CheckAnfNodeIfCNodeAndInputSize(AnfAlgo::GetInputNode(out_reshape, 0), kBackendReshapeInputTensorNum);
 MS_EXCEPTION_IF_NULL(in_reshape);
 if (IsUsedByOthers(func_graph, in_reshape)) {
 return nullptr;
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/reshape_transpose_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/reshape_transpose_fusion.cc
index fd41cfe56f..2e731dcf18 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/reshape_transpose_fusion.cc
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/reshape_transpose_fusion.cc
@@ -46,9 +46,9 @@ const AnfNodePtr ReshapeTransposeFusion::Process(const FuncGraphPtr &func_graph,
 const EquivPtr &equiv) const {
 MS_EXCEPTION_IF_NULL(func_graph);
 MS_EXCEPTION_IF_NULL(equiv);
- auto transpose_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kBackendReshapeInputNum);
+ auto transpose_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kBackendReshapeInputTensorNum);
 MS_EXCEPTION_IF_NULL(transpose_cnode);
- auto reshape_cnode = CheckAnfNodeIfCNodeAndInputSize(transpose_cnode->input(1), kBackendReshapeInputNum);
+ auto reshape_cnode = CheckAnfNodeIfCNodeAndInputSize(transpose_cnode->input(1), kBackendReshapeInputTensorNum);
 MS_EXCEPTION_IF_NULL(reshape_cnode);
 if (AnfAlgo::IsDynamicShape(transpose_cnode) || AnfAlgo::IsDynamicShape(reshape_cnode)) {
 return nullptr;
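The "+ 1" that recurs in index arithmetic (mul_addn_fusion and momentum_lossscale_fusion above) follows from the same input layout; a small worked example under the new constants, where kMulInputTensorNum is 2:

// Mul's CNode inputs are [prim, x, y], so its two tensor inputs sit at CNode
// indices 1 and 2. Given the value node at CNode index i (1 or 2), the other
// tensor input is at:
//   old: kMulInputNum - i            (kMulInputNum == 3 counted the primitive)
//   new: kMulInputTensorNum + 1 - i  (2 tensor inputs, plus 1 for the prim slot)
size_t other_index = kMulInputTensorNum + 1 - value_node_index;

diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/square_sum_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/square_sum_fusion.cc
index de2cb2e287..9912b4b988 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/square_sum_fusion.cc
+++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/square_sum_fusion.cc
@@ -33,10 +33,7 @@ CNodePtr GenerateSquareSumV1(const FuncGraphPtr &graph, const CNodePtr &square,
 MS_EXCEPTION_IF_NULL(graph);
 MS_EXCEPTION_IF_NULL(square);
 MS_EXCEPTION_IF_NULL(sum);
- if (square->inputs().size() != kSquareNodeInputNum) {
- MS_LOG(EXCEPTION) << "Square node has wrong input size"
- << " trace: " << trace::DumpSourceLines(square);
- }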
+ CheckCNodeInputSize(square, kSquareNodeInputTensorNum); auto prim = std::make_shared(kSquareSumV1OpName); MS_EXCEPTION_IF_NULL(prim); std::vector square_sumv1_inputs = {NewValueNode(prim), square->input(1)}; @@ -60,10 +57,7 @@ CNodePtr GenerateSquareSumV2(const FuncGraphPtr &graph, const CNodePtr &square, MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(square); MS_EXCEPTION_IF_NULL(sum); - if (square->inputs().size() != kSquareNodeInputNum) { - MS_LOG(EXCEPTION) << "Square node has wrong input size" - << " trace: " << trace::DumpSourceLines(square); - } + CheckCNodeInputSize(square, kSquareNodeInputTensorNum); auto prim = std::make_shared(kSquareSumV2OpName); MS_EXCEPTION_IF_NULL(prim); std::vector square_sumv2_inputs = {NewValueNode(prim), square->input(1)}; @@ -84,10 +78,7 @@ std::tuple GetPrevNodes(const AnfNodePtr &node) MS_EXCEPTION_IF_NULL(node); auto sum = node->cast(); MS_EXCEPTION_IF_NULL(sum); - if (sum->inputs().size() != kSumNodeInputNum) { - MS_LOG(EXCEPTION) << "ReduceSumD node has wrong input size" - << " trace: " << trace::DumpSourceLines(sum); - } + CheckCNodeInputSize(sum, kSumNodeInputTensorNum); auto square_anf = sum->input(1); MS_EXCEPTION_IF_NULL(square_anf); auto square = square_anf->cast(); diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/transpose_reshape_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/transpose_reshape_fusion.cc index 0643537d02..095d6ba040 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/transpose_reshape_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/transpose_reshape_fusion.cc @@ -46,9 +46,9 @@ const AnfNodePtr TransposeReshapeFusion::Process(const FuncGraphPtr &func_graph, const EquivPtr &equiv) const { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(equiv); - auto reshape_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kBackendReshapeInputNum); + auto reshape_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kBackendReshapeInputTensorNum); MS_EXCEPTION_IF_NULL(reshape_cnode); - auto transpose_cnode = CheckAnfNodeIfCNodeAndInputSize(reshape_cnode->input(1), kBackendReshapeInputNum); + auto transpose_cnode = CheckAnfNodeIfCNodeAndInputSize(reshape_cnode->input(1), kBackendReshapeInputTensorNum); MS_EXCEPTION_IF_NULL(transpose_cnode); if (AnfAlgo::IsDynamicShape(transpose_cnode) || AnfAlgo::IsDynamicShape(reshape_cnode)) { return nullptr; diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/transpose_transdata_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/transpose_transdata_fusion.cc index e757aa1290..3b521ea277 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/transpose_transdata_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/transpose_transdata_fusion.cc @@ -33,9 +33,9 @@ const AnfNodePtr TransposeTransDataFusion::Process(const FuncGraphPtr &func_grap const EquivPtr &equiv) const { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(equiv); - auto transdata_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kBackendTransposeInputNum); + auto transdata_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kBackendTransposeInputTensorNum); MS_EXCEPTION_IF_NULL(transdata_cnode); - auto transpose_cnode = CheckAnfNodeIfCNodeAndInputSize(transdata_cnode->input(1), kBackendTransDataInputNum); + auto transpose_cnode = CheckAnfNodeIfCNodeAndInputSize(transdata_cnode->input(1), kTransOpInputTensorNum); MS_EXCEPTION_IF_NULL(transpose_cnode); auto transpose_kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(transpose_cnode); auto 
transdata_kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(transdata_cnode); diff --git a/mindspore/ccsrc/backend/optimizer/ascend/mindir/conv2d_unify_mindir.cc b/mindspore/ccsrc/backend/optimizer/ascend/mindir/conv2d_unify_mindir.cc index f31fa66793..559174e557 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/mindir/conv2d_unify_mindir.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/mindir/conv2d_unify_mindir.cc @@ -136,10 +136,7 @@ CNodePtr CreateTranspose(const FuncGraphPtr &graph, const CNodePtr &conv2d, cons CNodePtr CreateDepthwiseConv2D(const FuncGraphPtr &graph, const CNodePtr &conv2d, const CNodePtr &transpose) { MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(conv2d); - if (conv2d->inputs().size() != kConvInputNum) { - MS_LOG(EXCEPTION) << "Conv2D's input number should be " << kConvInputNum - 1 << ", but got " - << conv2d->inputs().size() - 1; - } + CheckCNodeInputSize(conv2d, kConvInputTensorNum); std::vector depth_conv_inputs = {NewValueNode(std::make_shared(kDepthwiseConv2dNativeOpName)), conv2d->input(1), transpose}; auto depth_conv = graph->NewCNode(depth_conv_inputs); @@ -270,11 +267,7 @@ const AnfNodePtr Conv2DUnifyMindIR::Process(const FuncGraphPtr &graph, const Anf if (!NeedUpdate(conv2d, input_shape, output_shape)) { return nullptr; } - - if (conv2d->inputs().size() != kConvInputNum) { - MS_LOG(EXCEPTION) << "Conv2D's input number should be " << kConvInputNum - 1 << ", but got " - << conv2d->inputs().size() - 1; - } + CheckCNodeInputSize(conv2d, kConvInputTensorNum); auto transpose = CreateTranspose(graph, conv2d, conv2d->input(2), true); auto depth_conv = CreateDepthwiseConv2D(graph, conv2d, transpose); SetConv2DAttrs(conv2d, depth_conv); diff --git a/mindspore/ccsrc/backend/optimizer/ascend/mindir/optimizer_unify_output.cc b/mindspore/ccsrc/backend/optimizer/ascend/mindir/optimizer_unify_output.cc index 14799217f1..597ea4871d 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/mindir/optimizer_unify_output.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/mindir/optimizer_unify_output.cc @@ -70,7 +70,8 @@ const BaseRef FtrlUnifyOutput::DefinePattern() const { VarPtr l1 = std::make_shared(); VarPtr l2 = std::make_shared(); VarPtr lr_power = std::make_shared(); - VectorRef pattern({prim::kPrimApplyFtrl, var, accum, linear, grad, lr, l1, l2, lr_power}); + VarPtr u = std::make_shared(); + VectorRef pattern({prim::kPrimApplyFtrl, var, accum, linear, grad, lr, l1, l2, lr_power, u}); return pattern; } @@ -84,7 +85,8 @@ const BaseRef MomentumUnifyOutput::DefinePattern() const { VarPtr lr = std::make_shared(); VarPtr grad = std::make_shared(); VarPtr momentum = std::make_shared(); - VectorRef pattern({prim::kPrimApplyMomentum, var, accum, lr, grad, momentum}); + VarPtr u = std::make_shared(); + VectorRef pattern({prim::kPrimApplyMomentum, var, accum, lr, grad, momentum, u}); return pattern; } @@ -114,7 +116,8 @@ const BaseRef CenteredRMSPropUnifyOutput::DefinePattern() const { VarPtr rho = std::make_shared(); VarPtr momentum = std::make_shared(); VarPtr epsilon = std::make_shared(); - VectorRef pattern({prim::kPrimApplyCenteredRMSProp, var, mg, ms, mom, grad, lr, rho, momentum, epsilon}); + VarPtr u = std::make_shared(); + VectorRef pattern({prim::kPrimApplyCenteredRMSProp, var, mg, ms, mom, grad, lr, rho, momentum, epsilon, u}); return pattern; } diff --git a/mindspore/ccsrc/backend/optimizer/ascend/mindir/sparse_softmax_cross_entropy_with_logits_unify_mindir.cc 
b/mindspore/ccsrc/backend/optimizer/ascend/mindir/sparse_softmax_cross_entropy_with_logits_unify_mindir.cc index f937963b15..be88236141 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/mindir/sparse_softmax_cross_entropy_with_logits_unify_mindir.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/mindir/sparse_softmax_cross_entropy_with_logits_unify_mindir.cc @@ -109,12 +109,7 @@ CNodePtr CreateSoftmaxCrossEntropyWithLogits(const FuncGraphPtr &graph, const CN MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(sparse_softmax_node); MS_EXCEPTION_IF_NULL(one_hot_node); - - if (sparse_softmax_node->size() != kSparseSoftmaxCrossEntropyWithLogitsInputNum) { - MS_LOG(EXCEPTION) << "sparse_softmax_cross_entropy_with_logits's input size not equal " - << kSparseSoftmaxCrossEntropyWithLogitsInputNum; - } - + CheckCNodeInputSize(sparse_softmax_node, kSparseSoftmaxCrossEntropyWithLogitsInputTensorNum); std::vector inputs = {NewValueNode(std::make_shared(kSoftmaxCrossEntropyWithLogitsOpName)), sparse_softmax_node->input(1), one_hot_node}; auto softmax_node = graph->NewCNode(inputs); @@ -162,10 +157,7 @@ CNodePtr CreateReduceMean(const FuncGraphPtr &graph, const CNodePtr &sparse_soft MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(sparse_softmax_node); MS_EXCEPTION_IF_NULL(softmax_output_node); - if (sparse_softmax_node->size() != kSparseSoftmaxCrossEntropyWithLogitsInputNum) { - MS_LOG(EXCEPTION) << "sparse_softmax_cross_entropy_with_logits's input size not equal " - << kSparseSoftmaxCrossEntropyWithLogitsInputNum; - } + CheckCNodeInputSize(sparse_softmax_node, kSparseSoftmaxCrossEntropyWithLogitsInputTensorNum); auto axis_value = GetAxis(softmax_output_node); auto axis_node = GetAxisNode(softmax_output_node); @@ -200,9 +192,7 @@ CNodePtr CreateReduceMean(const FuncGraphPtr &graph, const CNodePtr &sparse_soft CNodePtr CreateExpandDims(const FuncGraphPtr &graph, const CNodePtr &real_div_node) { MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(real_div_node); - if (real_div_node->size() != kRealDivInputNum) { - MS_LOG(EXCEPTION) << "Op real_div's input num not equal " << kRealDivInputNum; - } + CheckCNodeInputSize(real_div_node, kRealDivInputTensorNum); int64_t axis = -1; auto axis_node = NewValueNode(axis); @@ -230,9 +220,8 @@ CNodePtr CreateExpandDims(const FuncGraphPtr &graph, const CNodePtr &real_div_no CNodePtr CreateExpandDimsPynative(const FuncGraphPtr &graph, const CNodePtr &real_div_node) { MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(real_div_node); - if (real_div_node->size() != kRealDivInputNum) { - MS_LOG(EXCEPTION) << "Op real_div's input num not equal " << kRealDivInputNum; - } + CheckCNodeInputSize(real_div_node, kRealDivInputTensorNum); + int64_t axis = -1; auto expand_dims_primitive = std::make_shared(kExpandDimsOpName); std::vector input_names = {"x"}; @@ -257,13 +246,8 @@ CNodePtr CreateTile(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax_no MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(sparse_softmax_node); MS_EXCEPTION_IF_NULL(mul_node); - if (sparse_softmax_node->size() != kSparseSoftmaxCrossEntropyWithLogitsInputNum) { - MS_LOG(EXCEPTION) << "sparse_softmax_cross_entropy_with_logits's input size not equal " - << kSparseSoftmaxCrossEntropyWithLogitsInputNum; - } - if (mul_node->size() != kMulInputNum) { - MS_LOG(EXCEPTION) << "Op Mul's input not equal " << kMulInputNum; - } + CheckCNodeInputSize(sparse_softmax_node, kSparseSoftmaxCrossEntropyWithLogitsInputTensorNum); + CheckCNodeInputSize(mul_node, kMulInputTensorNum); auto labels_shape = 
AnfAlgo::GetPrevNodeOutputInferShape(sparse_softmax_node, 1); std::vector multiple_value; @@ -310,12 +294,7 @@ CNodePtr CreateRealDiv(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(sparse_softmax_node); MS_EXCEPTION_IF_NULL(tile_node); - - if (sparse_softmax_node->size() != kSparseSoftmaxCrossEntropyWithLogitsInputNum) { - MS_LOG(EXCEPTION) << "sparse_softmax_cross_entropy_with_logits's input size not equal " - << kSparseSoftmaxCrossEntropyWithLogitsInputNum; - } - + CheckCNodeInputSize(sparse_softmax_node, kSparseSoftmaxCrossEntropyWithLogitsInputTensorNum); std::vector labels_shape = AnfAlgo::GetPrevNodeOutputInferShape(sparse_softmax_node, 1); if (labels_shape.size() != 1) { MS_LOG(EXCEPTION) << "label's shape should be 1-D."; @@ -343,9 +322,7 @@ CNodePtr CreateRealDiv(const FuncGraphPtr &graph, const CNodePtr &sparse_softmax CNodePtr GetSparseNode(const CNodePtr &depend_node, size_t index) { MS_EXCEPTION_IF_NULL(depend_node); - if (depend_node->size() != kDependInputNum) { - MS_LOG(EXCEPTION) << "Op Depend's input not equal " << kDependInputNum; - } + CheckCNodeInputSize(depend_node, kDependInputTensorNum); auto sparse_node = depend_node->input(index); MS_EXCEPTION_IF_NULL(sparse_node); return sparse_node->cast(); @@ -353,9 +330,7 @@ CNodePtr GetSparseNode(const CNodePtr &depend_node, size_t index) { CNodePtr GetDependNode(const CNodePtr &mul_node) { MS_EXCEPTION_IF_NULL(mul_node); - if (mul_node->size() != kMulInputNum) { - MS_LOG(EXCEPTION) << "Op Mul's input not equal " << kMulInputNum; - } + CheckCNodeInputSize(mul_node, kMulInputTensorNum); auto depend_node = mul_node->input(1); MS_EXCEPTION_IF_NULL(depend_node); return depend_node->cast(); @@ -413,10 +388,7 @@ const AnfNodePtr SparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process(const F auto sparse_softmax_node = node->cast(); MS_EXCEPTION_IF_NULL(sparse_softmax_node); - if (sparse_softmax_node->size() != kSparseSoftmaxCrossEntropyWithLogitsInputNum) { - MS_LOG(EXCEPTION) << "Op SparseSoftmaxCrossEntropyWithLogits's input not equal " - << kSparseSoftmaxCrossEntropyWithLogitsInputNum; - } + CheckCNodeInputSize(sparse_softmax_node, kSparseSoftmaxCrossEntropyWithLogitsInputTensorNum); if (AnfAlgo::HasNodeAttr(kAttrIsGrad, sparse_softmax_node) && AnfAlgo::GetNodeAttr(sparse_softmax_node, kAttrIsGrad)) { return nullptr; @@ -451,17 +423,12 @@ const AnfNodePtr GradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process(con auto mul_node = node->cast(); MS_EXCEPTION_IF_NULL(mul_node); - if (mul_node->size() != kMulInputNum) { - MS_LOG(EXCEPTION) << "Op Mul's input not equal " << kMulInputNum; - } + CheckCNodeInputSize(mul_node, kMulInputTensorNum); auto depend_node = GetDependNode(mul_node); auto sparse_softmax_node = GetSparseNode(depend_node, 2); auto sparse_softmax_node_grad = GetSparseNode(depend_node, 1); - if (sparse_softmax_node_grad->size() != kSparseSoftmaxCrossEntropyWithLogitsInputNum) { - MS_LOG(EXCEPTION) << "Op SparseSoftmaxCrossEntropyWithLogits's input not equal " - << kSparseSoftmaxCrossEntropyWithLogitsInputNum; - } + CheckCNodeInputSize(sparse_softmax_node_grad, kSparseSoftmaxCrossEntropyWithLogitsInputTensorNum); CNodePtr softmax_node; auto one_hot_node = CreateOneHot(graph, sparse_softmax_node_grad); @@ -538,10 +505,8 @@ const AnfNodePtr PynativeSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Process auto sparse_softmax_node = node->cast(); MS_EXCEPTION_IF_NULL(sparse_softmax_node); - if (sparse_softmax_node->size() != 
kSparseSoftmaxCrossEntropyWithLogitsInputNum) { - MS_LOG(EXCEPTION) << "Op SparseSoftmaxCrossEntropyWithLogits's input not equal " - << kSparseSoftmaxCrossEntropyWithLogitsInputNum; - } + CheckCNodeInputSize(sparse_softmax_node, kSparseSoftmaxCrossEntropyWithLogitsInputTensorNum); + if (AnfAlgo::HasNodeAttr(kAttrIsGrad, sparse_softmax_node) && AnfAlgo::GetNodeAttr(sparse_softmax_node, kAttrIsGrad)) { return nullptr; @@ -573,17 +538,12 @@ const AnfNodePtr PynativeGradSparseSoftmaxCrossEntropyWithLogitsUnifyMindIR::Pro auto mul_node = node->cast(); MS_EXCEPTION_IF_NULL(mul_node); - if (mul_node->size() != kMulInputNum) { - MS_LOG(EXCEPTION) << "Op Mul's input not equal " << kMulInputNum; - } + CheckCNodeInputSize(mul_node, kMulInputTensorNum); + auto sparse_softmax_node = mul_node->input(1); auto sparse_softmax_node_grad = sparse_softmax_node->cast(); MS_EXCEPTION_IF_NULL(sparse_softmax_node_grad); - - if (sparse_softmax_node_grad->size() != kSparseSoftmaxCrossEntropyWithLogitsInputNum) { - MS_LOG(EXCEPTION) << "Op SparseSoftmaxCrossEntropyWithLogits's input not equal " - << kSparseSoftmaxCrossEntropyWithLogitsInputNum; - } + CheckCNodeInputSize(sparse_softmax_node_grad, kSparseSoftmaxCrossEntropyWithLogitsInputTensorNum); CNodePtr softmax_node; auto one_hot_node = CreateOneHot(graph, sparse_softmax_node_grad); diff --git a/mindspore/ccsrc/backend/optimizer/common/helper.cc b/mindspore/ccsrc/backend/optimizer/common/helper.cc index 8b89a536bf..5c95bd2b76 100644 --- a/mindspore/ccsrc/backend/optimizer/common/helper.cc +++ b/mindspore/ccsrc/backend/optimizer/common/helper.cc @@ -124,18 +124,16 @@ CNodePtr CheckAnfNodeIfCNodeAndInputSize(const AnfNodePtr &node, size_t input_si MS_LOG(EXCEPTION) << "The node is expected to be a cnode"; } auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (cnode->inputs().size() != input_size) { - auto op_name = AnfAlgo::GetCNodeName(cnode); - MS_LOG(EXCEPTION) << "op[" + op_name + "] has less than " << input_size << " inputs."; - } + CheckCNodeInputSize(cnode, input_size); return cnode; } -void CheckCNodeInputSize(const CNodePtr &cnode, size_t input_size) { +void CheckCNodeInputSize(const CNodePtr &cnode, size_t input_tensor_size) { MS_EXCEPTION_IF_NULL(cnode); - if (cnode->inputs().size() != input_size) { - MS_LOG(EXCEPTION) << "The input size of node " + cnode->DebugString() + " is not equal to " << input_size; + auto real_input_tensor_num = AnfAlgo::GetInputTensorNum(cnode); + if (real_input_tensor_num != input_tensor_size) { + MS_LOG(EXCEPTION) << "The input tensor size[" << real_input_tensor_num + << "] of node " + cnode->DebugString() + " is not equal to " << input_tensor_size; } } @@ -149,17 +147,15 @@ bool HasSymmetricalKernelInfo(const AnfNodePtr &node_x, const AnfNodePtr &node_y const AnfNodePtr EliminateDependTransop(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(func_graph); - auto transop_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kTransOpInputNum); + auto transop_cnode = CheckAnfNodeIfCNodeAndInputSize(node, kTransOpInputTensorNum); MS_EXCEPTION_IF_NULL(transop_cnode); - auto depend_cnode = CheckAnfNodeIfCNodeAndInputSize(transop_cnode->input(kCastInputNum - 1), kDependInputNum); - auto prev_transop_cnode = CheckAnfNodeIfCNodeAndInputSize(depend_cnode->input(1), kTransOpInputNum); - MS_EXCEPTION_IF_NULL(depend_cnode->input(kDependInputNum - 1)); - MS_EXCEPTION_IF_NULL(prev_transop_cnode->input(kTransOpInputNum - 1)); - auto transed_node = prev_transop_cnode->input(kTransOpInputNum - 1); + auto 
depend_cnode = CheckAnfNodeIfCNodeAndInputSize(transop_cnode->input(1), kDependInputTensorNum); + auto prev_transop_cnode = CheckAnfNodeIfCNodeAndInputSize(depend_cnode->input(1), kTransOpInputTensorNum); + auto transed_node = prev_transop_cnode->input(1); MS_EXCEPTION_IF_NULL(transed_node); std::vector replace_depend_inputs{NewValueNode(prim::kPrimDepend), transed_node, - depend_cnode->input(kDependInputNum - 1)}; + depend_cnode->input(kDependAttachNodeIndex)}; AnfNodePtr replace_depend = func_graph->NewCNode(replace_depend_inputs); MS_EXCEPTION_IF_NULL(replace_depend); auto transed_abstract = transed_node->abstract(); @@ -422,13 +418,13 @@ std::shared_ptr>> GetRealNodeUsedList(con } auto output_info_list = iter->second; for (const auto &output_info : output_info_list) { - if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimControlDepend->name()) { - continue; - } if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimDepend->name() && output_info.second == kDependAttachNodeIndex) { continue; } + if (AnfAlgo::GetCNodeName(output_info.first) == prim::kPrimUpdateState->name()) { + continue; + } output_node_list->push_back(output_info); } return output_node_list; @@ -537,6 +533,9 @@ void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set &i bool need_update = false; for (size_t i = 0; i < inputs.size() - 1; ++i) { auto input_node = inputs[i + 1]; + if (AnfAlgo::CheckPrimitiveType(input_node, prim::kPrimDepend)) { + input_node = AnfAlgo::VisitKernel(input_node, 0).first; + } MS_EXCEPTION_IF_NULL(input_node); if (input_attrs.find(i) != input_attrs.end() && input_node->isa()) { auto value_node = input_node->cast(); @@ -548,7 +547,7 @@ void ConstInputToAttr(const CNodePtr &cnode, const std::unordered_set &i primitive->set_attr(input_names_vec[i], value_node->value()); need_update = true; } else { - new_inputs.push_back(input_node); + new_inputs.push_back(inputs[i + 1]); } } if (need_update) { @@ -785,7 +784,8 @@ ValueNodePtr MakeValueNode(const ValueNodePtr &value_node) { kernel_build_info_builder->SetOutputsFormat(std::vector{kOpFormat_DEFAULT}); // set value node initial device data type = infer data type std::vector types; - for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(value_node); ++index) { + size_t output_num = AnfAlgo::GetOutputTensorNum(value_node); + for (size_t index = 0; index < output_num; ++index) { types.push_back(kTypeUnknown); } kernel_build_info_builder->SetOutputsDeviceType(types); diff --git a/mindspore/ccsrc/backend/optimizer/common/helper.h b/mindspore/ccsrc/backend/optimizer/common/helper.h index 55162164ef..ee3af39da3 100644 --- a/mindspore/ccsrc/backend/optimizer/common/helper.h +++ b/mindspore/ccsrc/backend/optimizer/common/helper.h @@ -29,36 +29,34 @@ namespace mindspore { namespace opt { -constexpr size_t kTransOpInputNum = 2; -constexpr size_t kCastInputNum = 2; -constexpr size_t kDependInputNum = 3; -constexpr size_t kReluInputNum = 2; -constexpr size_t kReluGradInputNum = 3; -constexpr size_t kAddInputNum = 3; -constexpr size_t kAddNInputNum = 3; -constexpr size_t kTupleGetitemInputNum = 3; -constexpr size_t kConvInputNum = 3; -constexpr size_t kRealDivInputNum = 3; -constexpr size_t kSqrtInputNum = 2; -constexpr size_t kMulInputNum = 3; -constexpr size_t kRsqrtInputNum = 2; -constexpr size_t kSubInputNum = 3; -constexpr size_t kAssignSubInputNum = 3; -constexpr size_t kDropoutInputNum = 2; +constexpr size_t kTransOpInputTensorNum = 1; +constexpr size_t kCastInputTensorNum = 1; +constexpr size_t kDependInputTensorNum = 2; 
+constexpr size_t kReluInputTensorNum = 1; +constexpr size_t kReluGradInputTensorNum = 2; +constexpr size_t kAddInputTensorNum = 2; +constexpr size_t kTupleGetItemInputTensorNum = 2; +constexpr size_t kConvInputTensorNum = 2; +constexpr size_t kRealDivInputTensorNum = 2; +constexpr size_t kSqrtInputTensorNum = 1; +constexpr size_t kMatMulInputTensorNum = 2; +constexpr size_t kMulInputTensorNum = 2; +constexpr size_t kSubInputTensorNum = 2; +constexpr size_t kAssignSubInputTensorNum = 2; +constexpr size_t kDropoutInputTensorNum = 1; +constexpr size_t kAssignInputTensorNum = 2; constexpr size_t kConvBn1OutputNum = 3; constexpr size_t kBn2ReluOutputNum = 4; -constexpr size_t kBnInputNum = 6; +constexpr size_t kBnInputTensorNum = 5; constexpr size_t kBnOutputNum = 5; -constexpr size_t kBatchNormInputNum = 5; -constexpr size_t kBatchNormOutputNum = 5; constexpr size_t kBN1OutputNum = 2; constexpr size_t kBN2OutputNum = 3; constexpr size_t kBN3OutputNum = 1; -constexpr size_t kBNGradInputNum = 6; +constexpr size_t kBNGradInputTensorNum = 5; constexpr size_t kBNGradOutputNum = 3; constexpr size_t kBNGrad1OutputNum = 3; @@ -72,10 +70,10 @@ constexpr size_t kBNTrainingUpdateV3OutputNum = 5; constexpr size_t kBNTrainingUpdateGradOutputNum = 2; constexpr size_t kSingleOutputNum = 1; -constexpr size_t kSumNodeInputNum = 2; -constexpr size_t kSquareNodeInputNum = 2; +constexpr size_t kSumNodeInputTensorNum = 1; +constexpr size_t kSquareNodeInputTensorNum = 1; constexpr size_t kSquareSumv2OutputNum = 2; -constexpr size_t kMinimumInputNum = 3; +constexpr size_t kMinimumInputTensorNum = 2; constexpr size_t kLambNextMVWithDecayInputNum = 7; constexpr size_t kLambNextMVWithDecayConstantMulInputNum = 5; @@ -85,26 +83,25 @@ constexpr size_t kLambNextRightOutputNum = 2; constexpr size_t kLambUpdateWithLrV2InputNum = 8; constexpr size_t kLambNextMVRuleInputNum = 14; constexpr size_t kLambNextMVRuleOutputNum = 4; -constexpr size_t kBackendReshapeInputNum = 2; -constexpr size_t kBackendTransposeInputNum = 2; +constexpr size_t kBackendReshapeInputTensorNum = 1; +constexpr size_t kBackendTransposeInputTensorNum = 1; constexpr size_t kAdamApplyOneWithDecayOutputNum = 3; -constexpr size_t kLayerNormBetaGammaBackpropInputNum = 5; +constexpr size_t kLayerNormBetaGammaBackpropInputTensorNum = 4; constexpr size_t kLayerNormBetaGammaBackpropOutputNum = 2; -constexpr size_t kLayerNormGradInputNum = 6; +constexpr size_t kLayerNormGradInputTensorNum = 5; constexpr size_t kAdamApplyOneOutputNum = 3; -constexpr size_t kBackendTransDataInputNum = 2; -constexpr size_t kApplyMomentumInputNum = 6; -constexpr size_t kBiasAddInputNum = 3; -constexpr size_t kTopkInputNum = 3; -constexpr size_t kLarsV2InputNum = 5; +constexpr size_t kApplyMomentumInputTensorNum = 5; +constexpr size_t kBiasAddInputTensorNum = 2; +constexpr size_t kTopkInputTensorNum = 2; +constexpr size_t kLarsV2InputTensorNum = 4; constexpr size_t kFusedMulApplyMomentumOutputNum = 2; -constexpr size_t kSplitInputNum = 2; -constexpr size_t kGatherV2DynInputNum = 3; -constexpr size_t kUnsortedSegmentSumInputNum = 2; +constexpr size_t kSplitInputTensorNum = 1; +constexpr size_t kGatherV2DynInputTensorNum = 3; +constexpr size_t kUnsortedSegmentSumInputTensorNum = 2; constexpr size_t kSoftmaxCrossEntropyWithLogitsOutputNum = 2; -constexpr size_t kSparseSoftmaxCrossEntropyWithLogitsInputNum = 3; +constexpr size_t kSparseSoftmaxCrossEntropyWithLogitsInputTensorNum = 2; constexpr size_t kOneHotOutputNum = 1; -constexpr size_t kOneHotInputNum = 5; +constexpr size_t 
kOneHotInputTensorNum = 4; enum FusedBatchNormInput { kX = 1, @@ -137,7 +134,7 @@ bool Visited(const BaseRef &n); // check if the input node is CNode, then check it's input_size, return CNodePtr if check success. CNodePtr CheckAnfNodeIfCNodeAndInputSize(const AnfNodePtr &node, size_t input_size); -void CheckCNodeInputSize(const CNodePtr &cnode, size_t input_size); +void CheckCNodeInputSize(const CNodePtr &cnode, size_t input_tensor_num); bool HasSymmetricalKernelInfo(const AnfNodePtr &node_x, const AnfNodePtr &node_y); diff --git a/mindspore/ccsrc/backend/optimizer/gpu/adam_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/adam_fusion.cc index a170d2002c..9caedbe23c 100644 --- a/mindspore/ccsrc/backend/optimizer/gpu/adam_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/gpu/adam_fusion.cc @@ -34,11 +34,13 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) { std::vector outputs_type; kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; - for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(node); ++input_index) { + size_t input_num = AnfAlgo::GetInputTensorNum(node); + for (size_t input_index = 0; input_index < input_num; ++input_index) { inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index)); inputs_format.push_back(kOpFormat_DEFAULT); } - for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(node); ++output_index) { + size_t output_num = AnfAlgo::GetOutputTensorNum(node); + for (size_t output_index = 0; output_index < output_num; ++output_index) { outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, output_index)); outputs_format.push_back(kOpFormat_DEFAULT); } @@ -51,19 +53,30 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) { } // namespace const BaseRef AdamFusion::DefinePattern() const { - VectorRef next_m = VectorRef( - {prim::kPrimAdd, VectorRef({prim::kPrimMul, beta1_, m_}), VectorRef({prim::kPrimMul, one_sub_beta1_, gradient_})}); + VectorRef load_m = VectorRef({prim::kPrimLoad, m_, u_}); + VectorRef next_m = VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, beta1_, load_m}), + VectorRef({prim::kPrimMul, one_sub_beta1_, gradient_})}); + + VectorRef load_v = VectorRef({prim::kPrimLoad, v_, u_}); VectorRef next_v = - VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, beta2_, v_}), + VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, beta2_, load_v}), VectorRef({prim::kPrimMul, one_sub_beta2_, VectorRef({prim::kPrimSquare, gradient_})})}); + VectorRef update = VectorRef({prim::kPrimRealDiv, next_m, VectorRef({prim::kPrimAdd, eps_, VectorRef({prim::kPrimSqrt, next_v})})}); VectorRef update_with_lr = VectorRef({prim::kPrimMul, lr_, update}); VectorRef next_param = VectorRef({prim::kPrimSub, param_, update_with_lr}); - next_param = VectorRef({prim::kPrimDepend, next_param, VectorRef({prim::kPrimAssign, param_, next_param})}); - next_param = VectorRef({prim::kPrimDepend, next_param, VectorRef({prim::kPrimAssign, m_, next_m})}); - next_param = VectorRef({prim::kPrimDepend, next_param, VectorRef({prim::kPrimAssign, v_, next_v})}); + VectorRef assign_param = VectorRef({prim::kPrimAssign, param_, next_param, u2_}); + VectorRef next_state = VectorRef({prim::kPrimUpdateState, u2_, assign_param}); + next_param = VectorRef({prim::kPrimDepend, next_param, assign_param}); + + VectorRef assign_m = VectorRef({prim::kPrimAssign, m_, next_m, next_state}); + next_state = VectorRef({prim::kPrimUpdateState, next_state, assign_m}); + next_param = VectorRef({prim::kPrimDepend, 
next_param, assign_m}); + + VectorRef assign_v = VectorRef({prim::kPrimAssign, v_, next_v, next_state}); + next_param = VectorRef({prim::kPrimDepend, next_param, assign_v}); return next_param; } @@ -81,6 +94,7 @@ const AnfNodePtr AdamFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr auto m_input = utils::cast((*equiv)[m_]); auto v_input = utils::cast((*equiv)[v_]); auto gradient_input = utils::cast((*equiv)[gradient_]); + auto u_input = utils::cast((*equiv)[u_]); MS_EXCEPTION_IF_NULL(beta1_input); MS_EXCEPTION_IF_NULL(one_sub_beta1_input); MS_EXCEPTION_IF_NULL(beta2_input); @@ -91,13 +105,30 @@ const AnfNodePtr AdamFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr MS_EXCEPTION_IF_NULL(m_input); MS_EXCEPTION_IF_NULL(v_input); MS_EXCEPTION_IF_NULL(gradient_input); + MS_EXCEPTION_IF_NULL(u_input); + + // Use depend(param, u) to maintain the execution order of FusedAdam and the previous operators. + auto prim_depend = std::make_shared(prim::kPrimDepend->name()); + MS_EXCEPTION_IF_NULL(prim_depend); + std::vector param_inputs = {NewValueNode(prim_depend), param_input, u_input}; + auto param = graph->NewCNode(param_inputs); + MS_EXCEPTION_IF_NULL(param); + param->set_abstract(param_input->abstract()); + // Fused into a FusedAdam operator. auto prim = std::make_shared(kFusedAdamName); MS_EXCEPTION_IF_NULL(prim); - std::vector inputs = { - NewValueNode(prim), beta1_input, one_sub_beta1_input, beta2_input, one_sub_beta2_input, - eps_input, lr_input, param_input, m_input, v_input, - gradient_input}; + std::vector inputs = {NewValueNode(prim), + beta1_input, + one_sub_beta1_input, + beta2_input, + one_sub_beta2_input, + eps_input, + lr_input, + param, + m_input, + v_input, + gradient_input}; auto adam = graph->NewCNode(inputs); MS_EXCEPTION_IF_NULL(adam); auto types = {AnfAlgo::GetOutputInferDataType(node, 0)}; @@ -107,6 +138,30 @@ const AnfNodePtr AdamFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr auto build_info = GenerateKernelBuildInfo(adam); AnfAlgo::SetSelectKernelBuildInfo(build_info, adam.get()); + + // Replace the parameters of the last UpdateState to maintain + // the execution order of FusedAdam and the following operators. 
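+ // The matched assign_v still feeds that trailing UpdateState, so its two inputs are
+ // re-pointed to the original monad u and to the fused node; operators that waited on the
+ // UpdateState therefore wait on FusedAdam instead.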
+ // n represents the operator assign_v in {prim::kPrimDepend, next_param, assign_v} + auto n = node->cast()->input(2); + auto fg = n->func_graph(); + MS_EXCEPTION_IF_NULL(fg); + auto mgr = fg->manager(); + MS_EXCEPTION_IF_NULL(mgr); + auto &node_users = mgr->node_users(); + auto iter = node_users.find(n); + if (iter == node_users.end()) { + MS_LOG(EXCEPTION) << "Cannot find node: " << n->DebugString(); + } + + auto &users = iter->second; + for (auto &user : users) { + if (IsPrimitiveCNode(user.first, prim::kPrimUpdateState)) { + (user.first)->cast()->set_input(1, u_input); + (user.first)->cast()->set_input(2, adam); + break; + } + } + return adam; } } // namespace opt diff --git a/mindspore/ccsrc/backend/optimizer/gpu/adam_fusion.h b/mindspore/ccsrc/backend/optimizer/gpu/adam_fusion.h index 1fa339c3f3..a83c4281da 100644 --- a/mindspore/ccsrc/backend/optimizer/gpu/adam_fusion.h +++ b/mindspore/ccsrc/backend/optimizer/gpu/adam_fusion.h @@ -34,6 +34,8 @@ class AdamFusion : public PatternProcessPass { m_ = std::make_shared(); v_ = std::make_shared(); gradient_ = std::make_shared(); + u_ = std::make_shared(); + u2_ = std::make_shared(); } ~AdamFusion() override = default; const BaseRef DefinePattern() const override; @@ -50,6 +52,8 @@ class AdamFusion : public PatternProcessPass { VarPtr m_; VarPtr v_; VarPtr gradient_; + VarPtr u_; + VarPtr u2_; }; } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/gpu/adam_weight_decay_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/adam_weight_decay_fusion.cc index 89cefe9f1b..0a2924790c 100644 --- a/mindspore/ccsrc/backend/optimizer/gpu/adam_weight_decay_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/gpu/adam_weight_decay_fusion.cc @@ -34,11 +34,13 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) { std::vector outputs_type; kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; - for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(node); ++input_index) { + size_t input_num = AnfAlgo::GetInputTensorNum(node); + for (size_t input_index = 0; input_index < input_num; ++input_index) { inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index)); inputs_format.push_back(kOpFormat_DEFAULT); } - for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(node); ++output_index) { + size_t output_num = AnfAlgo::GetOutputTensorNum(node); + for (size_t output_index = 0; output_index < output_num; ++output_index) { outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, output_index)); outputs_format.push_back(kOpFormat_DEFAULT); } @@ -51,11 +53,14 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) { } // namespace const BaseRef AdamWeightDecayFusion::DefinePattern() const { - VectorRef next_m = VectorRef( - {prim::kPrimAdd, VectorRef({prim::kPrimMul, beta1_, m_}), VectorRef({prim::kPrimMul, one_sub_beta1_, gradient_})}); + VectorRef load_m = VectorRef({prim::kPrimLoad, m_, u_}); + VectorRef next_m = VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, beta1_, load_m}), + VectorRef({prim::kPrimMul, one_sub_beta1_, gradient_})}); + VectorRef load_v = VectorRef({prim::kPrimLoad, v_, u_}); VectorRef next_v = - VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, beta2_, v_}), + VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, beta2_, load_v}), VectorRef({prim::kPrimMul, one_sub_beta2_, VectorRef({prim::kPrimSquare, gradient_})})}); + VectorRef update = VectorRef({prim::kPrimRealDiv, next_m, VectorRef({prim::kPrimAdd, eps_,
VectorRef({prim::kPrimSqrt, next_v})})}); VectorRef new_update = VectorRef({prim::kPrimAdd, VectorRef({prim::kPrimMul, weight_decay_, param_}), update}); @@ -63,9 +68,16 @@ const BaseRef AdamWeightDecayFusion::DefinePattern() const { VectorRef update_with_lr = VectorRef({prim::kPrimMul, lr_, new_update}); VectorRef next_param = VectorRef({prim::kPrimSub, param_, update_with_lr}); - next_param = VectorRef({prim::kPrimDepend, next_param, VectorRef({prim::kPrimAssign, param_, next_param})}); - next_param = VectorRef({prim::kPrimDepend, next_param, VectorRef({prim::kPrimAssign, m_, next_m})}); - next_param = VectorRef({prim::kPrimDepend, next_param, VectorRef({prim::kPrimAssign, v_, next_v})}); + VectorRef assign_param = VectorRef({prim::kPrimAssign, param_, next_param, u2_}); + VectorRef next_state = VectorRef({prim::kPrimUpdateState, u2_, assign_param}); + next_param = VectorRef({prim::kPrimDepend, next_param, assign_param}); + + VectorRef assign_m = VectorRef({prim::kPrimAssign, m_, next_m, next_state}); + next_state = VectorRef({prim::kPrimUpdateState, next_state, assign_m}); + next_param = VectorRef({prim::kPrimDepend, next_param, assign_m}); + + VectorRef assign_v = VectorRef({prim::kPrimAssign, v_, next_v, next_state}); + next_param = VectorRef({prim::kPrimDepend, next_param, assign_v}); return next_param; } @@ -85,6 +97,7 @@ const AnfNodePtr AdamWeightDecayFusion::Process(const FuncGraphPtr &graph, const auto m_input = utils::cast((*equiv)[m_]); auto v_input = utils::cast((*equiv)[v_]); auto gradient_input = utils::cast((*equiv)[gradient_]); + auto u_input = utils::cast((*equiv)[u_]); MS_EXCEPTION_IF_NULL(beta1_input); MS_EXCEPTION_IF_NULL(one_sub_beta1_input); MS_EXCEPTION_IF_NULL(beta2_input); @@ -96,13 +109,31 @@ const AnfNodePtr AdamWeightDecayFusion::Process(const FuncGraphPtr &graph, const MS_EXCEPTION_IF_NULL(m_input); MS_EXCEPTION_IF_NULL(v_input); MS_EXCEPTION_IF_NULL(gradient_input); + MS_EXCEPTION_IF_NULL(u_input); + // Use depend(param, u) to maintain the execution order of FusedAdamWeightDecay and the previous operators. + auto prim_depend = std::make_shared(prim::kPrimDepend->name()); + MS_EXCEPTION_IF_NULL(prim_depend); + std::vector param_inputs = {NewValueNode(prim_depend), param_input, u_input}; + auto param = graph->NewCNode(param_inputs); + MS_EXCEPTION_IF_NULL(param); + param->set_abstract(param_input->abstract()); + + // Fused into a FusedAdamWeightDecay operator. auto prim = std::make_shared(kFusedAdamWeightDecayName); MS_EXCEPTION_IF_NULL(prim); - std::vector inputs = { - NewValueNode(prim), beta1_input, one_sub_beta1_input, beta2_input, one_sub_beta2_input, - eps_input, lr_input, param_input, m_input, v_input, - gradient_input, weight_decay_input}; + std::vector inputs = {NewValueNode(prim), + beta1_input, + one_sub_beta1_input, + beta2_input, + one_sub_beta2_input, + eps_input, + lr_input, + param, + m_input, + v_input, + gradient_input, + weight_decay_input}; auto adam_weight_decay = graph->NewCNode(inputs); MS_EXCEPTION_IF_NULL(adam_weight_decay); auto types = {AnfAlgo::GetOutputInferDataType(node, 0)}; @@ -112,6 +143,30 @@ const AnfNodePtr AdamWeightDecayFusion::Process(const FuncGraphPtr &graph, const auto build_info = GenerateKernelBuildInfo(adam_weight_decay); AnfAlgo::SetSelectKernelBuildInfo(build_info, adam_weight_decay.get()); + + // Replace the parameters of the last UpdateState to maintain + // the execution order of FusedAdamWeightDecay and the following operators. 
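+ // The matched assign_v still feeds that trailing UpdateState, so its two inputs are
+ // re-pointed to the original monad u and to the fused node; operators that waited on the
+ // UpdateState therefore wait on FusedAdamWeightDecay instead.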
+ // n represents the operator assign_v in {prim::kPrimDepend, next_param, assign_v} + auto n = node->cast()->input(2); + auto fg = n->func_graph(); + MS_EXCEPTION_IF_NULL(fg); + auto mgr = fg->manager(); + MS_EXCEPTION_IF_NULL(mgr); + auto &node_users = mgr->node_users(); + auto iter = node_users.find(n); + if (iter == node_users.end()) { + MS_LOG(EXCEPTION) << "Cannot find node: " << n->DebugString(); + } + + auto &users = iter->second; + for (auto &user : users) { + if (IsPrimitiveCNode(user.first, prim::kPrimUpdateState)) { + (user.first)->cast()->set_input(1, u_input); + (user.first)->cast()->set_input(2, adam_weight_decay); + break; + } + } + return adam_weight_decay; } } // namespace opt diff --git a/mindspore/ccsrc/backend/optimizer/gpu/adam_weight_decay_fusion.h b/mindspore/ccsrc/backend/optimizer/gpu/adam_weight_decay_fusion.h index 015ce63206..d29cefb222 100644 --- a/mindspore/ccsrc/backend/optimizer/gpu/adam_weight_decay_fusion.h +++ b/mindspore/ccsrc/backend/optimizer/gpu/adam_weight_decay_fusion.h @@ -35,6 +35,8 @@ class AdamWeightDecayFusion : public PatternProcessPass { m_ = std::make_shared(); v_ = std::make_shared(); gradient_ = std::make_shared(); + u_ = std::make_shared(); + u2_ = std::make_shared(); } ~AdamWeightDecayFusion() override = default; const BaseRef DefinePattern() const override; @@ -52,6 +54,8 @@ class AdamWeightDecayFusion : public PatternProcessPass { VarPtr m_; VarPtr v_; VarPtr gradient_; + VarPtr u_; + VarPtr u2_; }; } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/gpu/add_relu_grad_v2_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/add_relu_grad_v2_fusion.cc index e35f97fa6e..2308ae6dab 100644 --- a/mindspore/ccsrc/backend/optimizer/gpu/add_relu_grad_v2_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/gpu/add_relu_grad_v2_fusion.cc @@ -34,11 +34,13 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) { std::vector outputs_type; kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; - for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(node); ++input_index) { + size_t input_num = AnfAlgo::GetInputTensorNum(node); + for (size_t input_index = 0; input_index < input_num; ++input_index) { inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index)); inputs_format.push_back(kOpFormat_DEFAULT); } - for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(node); ++output_index) { + size_t output_num = AnfAlgo::GetOutputTensorNum(node); + for (size_t output_index = 0; output_index < output_num; ++output_index) { outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, output_index)); outputs_format.push_back(kOpFormat_DEFAULT); } diff --git a/mindspore/ccsrc/backend/optimizer/gpu/add_relu_v2_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/add_relu_v2_fusion.cc index f643b2869e..1cea7134c4 100644 --- a/mindspore/ccsrc/backend/optimizer/gpu/add_relu_v2_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/gpu/add_relu_v2_fusion.cc @@ -34,11 +34,13 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) { std::vector outputs_type; kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; - for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(node); ++input_index) { + size_t input_num = AnfAlgo::GetInputTensorNum(node); + for (size_t input_index = 0; input_index < input_num; ++input_index) { inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index)); inputs_format.push_back(kOpFormat_DEFAULT);
} - for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(node); ++output_index) { + size_t output_num = AnfAlgo::GetOutputTensorNum(node); + for (size_t output_index = 0; output_index < output_num; ++output_index) { outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, output_index)); outputs_format.push_back(kOpFormat_DEFAULT); } @@ -78,7 +80,8 @@ const AnfNodePtr AddReluV2Fusion::Process(const FuncGraphPtr &graph, const AnfNo std::vector types; std::vector> shapes; - for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(node); i++) { + size_t output_num = AnfAlgo::GetOutputTensorNum(node); + for (size_t i = 0; i < output_num; i++) { types.push_back(AnfAlgo::GetOutputInferDataType(node, i)); shapes.push_back(AnfAlgo::GetOutputInferShape(node, i)); } diff --git a/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_scale_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_scale_fusion.cc index 1c139c3660..d19667089d 100644 --- a/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_scale_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_scale_fusion.cc @@ -51,7 +51,7 @@ bool ApplyMomentumScaleFusion::IsScalar(const BaseRef &n) { const BaseRef ApplyMomentumScaleFusion::DefinePattern() const { VectorRef scale = VectorRef({prim::kPrimMul, gradient_, scale_}); VectorRef apply_momentum = - VectorRef({prim::kPrimApplyMomentum, variable_, accumulation_, learning_rate_, scale, momentum_}); + VectorRef({prim::kPrimApplyMomentum, variable_, accumulation_, learning_rate_, scale, momentum_, monad_state_}); return apply_momentum; } @@ -66,17 +66,19 @@ const AnfNodePtr ApplyMomentumScaleFusion::Process(const FuncGraphPtr &graph, co auto learning_rate = utils::cast((*equiv)[learning_rate_]); auto gradient = utils::cast((*equiv)[gradient_]); auto momentum = utils::cast((*equiv)[momentum_]); + auto monad_state = utils::cast((*equiv)[monad_state_]); MS_EXCEPTION_IF_NULL(scale); MS_EXCEPTION_IF_NULL(variable); MS_EXCEPTION_IF_NULL(accumulation); MS_EXCEPTION_IF_NULL(learning_rate); MS_EXCEPTION_IF_NULL(gradient); MS_EXCEPTION_IF_NULL(momentum); + MS_EXCEPTION_IF_NULL(monad_state); auto prim = std::make_shared(kFusedScaleApplyMomentum); MS_EXCEPTION_IF_NULL(prim); std::vector inputs = {NewValueNode(prim), scale, variable, accumulation, - learning_rate, gradient, momentum}; + learning_rate, gradient, momentum, monad_state}; auto replace_node = graph->NewCNode(inputs); MS_EXCEPTION_IF_NULL(replace_node); auto types = {AnfAlgo::GetOutputInferDataType(node, 0)}; diff --git a/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_scale_fusion.h b/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_scale_fusion.h index 8888f40c7b..1271492d38 100644 --- a/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_scale_fusion.h +++ b/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_scale_fusion.h @@ -31,6 +31,7 @@ class ApplyMomentumScaleFusion : public PatternProcessPass { learning_rate_ = std::make_shared(); gradient_ = std::make_shared(); momentum_ = std::make_shared(); + monad_state_ = std::make_shared(); } ~ApplyMomentumScaleFusion() override = default; const BaseRef DefinePattern() const override; @@ -45,6 +46,7 @@ class ApplyMomentumScaleFusion : public PatternProcessPass { VarPtr learning_rate_; VarPtr gradient_; VarPtr momentum_; + VarPtr monad_state_; }; } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_weight_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_weight_fusion.cc index 
3c851bcde4..febfa3b35f 100644 --- a/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_weight_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_weight_fusion.cc @@ -49,10 +49,11 @@ bool ApplyMomentumWeightDecayFusion::IsScalar(const BaseRef &n) { } const BaseRef ApplyMomentumWeightDecayFusion::DefinePattern() const { + VectorRef load_para = VectorRef({prim::kPrimLoad, variable_, monad_}); VectorRef weight_decay = - VectorRef({prim::kPrimAddN, VectorRef({prim::kPrimMul, variable_, weight_decay_}), gradient_}); - VectorRef apply_momentum = - VectorRef({prim::kPrimApplyMomentum, variable_, accumulation_, learning_rate_, weight_decay, momentum_}); + VectorRef({prim::kPrimAddN, VectorRef({prim::kPrimMul, load_para, weight_decay_}), gradient_}); + VectorRef apply_momentum = VectorRef( + {prim::kPrimApplyMomentum, variable_, accumulation_, learning_rate_, weight_decay, momentum_, monad_state_}); return apply_momentum; } @@ -67,17 +68,19 @@ const AnfNodePtr ApplyMomentumWeightDecayFusion::Process(const FuncGraphPtr &gra auto learning_rate = utils::cast((*equiv)[learning_rate_]); auto gradient = utils::cast((*equiv)[gradient_]); auto momentum = utils::cast((*equiv)[momentum_]); + auto monad_state = utils::cast((*equiv)[monad_state_]); MS_EXCEPTION_IF_NULL(weight_decay); MS_EXCEPTION_IF_NULL(variable); MS_EXCEPTION_IF_NULL(accumulation); MS_EXCEPTION_IF_NULL(learning_rate); MS_EXCEPTION_IF_NULL(gradient); MS_EXCEPTION_IF_NULL(momentum); + MS_EXCEPTION_IF_NULL(monad_state); auto prim = std::make_shared(kFusedWeightApplyMomentum); MS_EXCEPTION_IF_NULL(prim); std::vector inputs = {NewValueNode(prim), weight_decay, variable, accumulation, - learning_rate, gradient, momentum}; + learning_rate, gradient, momentum, monad_state}; auto replace_node = graph->NewCNode(inputs); MS_EXCEPTION_IF_NULL(replace_node); auto types = {AnfAlgo::GetOutputInferDataType(node, 0)}; diff --git a/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_weight_fusion.h b/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_weight_fusion.h index 1b49394e43..6749b58aa9 100644 --- a/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_weight_fusion.h +++ b/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_weight_fusion.h @@ -25,12 +25,14 @@ class ApplyMomentumWeightDecayFusion : public PatternProcessPass { public: explicit ApplyMomentumWeightDecayFusion(bool multigraph = true) : PatternProcessPass("momentum_weightdecay_fusion", multigraph) { + monad_ = std::make_shared(); weight_decay_ = std::make_shared(); variable_ = std::make_shared(); accumulation_ = std::make_shared(); learning_rate_ = std::make_shared(); gradient_ = std::make_shared(); momentum_ = std::make_shared(); + monad_state_ = std::make_shared(); } ~ApplyMomentumWeightDecayFusion() override = default; const BaseRef DefinePattern() const override; @@ -39,12 +41,14 @@ class ApplyMomentumWeightDecayFusion : public PatternProcessPass { private: static bool IsScalar(const BaseRef &n); + VarPtr monad_; VarPtr weight_decay_; VarPtr variable_; VarPtr accumulation_; VarPtr learning_rate_; VarPtr gradient_; VarPtr momentum_; + VarPtr monad_state_; }; } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_weight_scale_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_weight_scale_fusion.cc index 743015c50c..4c206712e2 100644 --- a/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_weight_scale_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_weight_scale_fusion.cc @@ 
-49,11 +49,12 @@ bool ApplyMomentumWeightDecayScaleFusion::IsScalar(const BaseRef &n) { } const BaseRef ApplyMomentumWeightDecayScaleFusion::DefinePattern() const { + VectorRef load_para = VectorRef({prim::kPrimLoad, variable_, monad_}); VectorRef weight = VectorRef( - {prim::kPrimAddN, VectorRef({prim::kPrimMul, variable_, weight_decay_}), VectorRef({prim::kPrimCast, gradient_})}); + {prim::kPrimAddN, VectorRef({prim::kPrimMul, load_para, weight_decay_}), VectorRef({prim::kPrimCast, gradient_})}); VectorRef scale = VectorRef({prim::kPrimMul, weight, scale_}); VectorRef apply_momentum = - VectorRef({prim::kPrimApplyMomentum, variable_, accumulation_, learning_rate_, scale, momentum_}); + VectorRef({prim::kPrimApplyMomentum, variable_, accumulation_, learning_rate_, scale, momentum_, monad_state_}); return apply_momentum; } @@ -69,6 +70,8 @@ const AnfNodePtr ApplyMomentumWeightDecayScaleFusion::Process(const FuncGraphPtr auto learning_rate = utils::cast((*equiv)[learning_rate_]); auto gradient = utils::cast((*equiv)[gradient_]); auto momentum = utils::cast((*equiv)[momentum_]); + auto monad_state = utils::cast((*equiv)[monad_state_]); + MS_EXCEPTION_IF_NULL(weight_decay); MS_EXCEPTION_IF_NULL(scale); MS_EXCEPTION_IF_NULL(variable); @@ -76,11 +79,12 @@ const AnfNodePtr ApplyMomentumWeightDecayScaleFusion::Process(const FuncGraphPtr MS_EXCEPTION_IF_NULL(learning_rate); MS_EXCEPTION_IF_NULL(gradient); MS_EXCEPTION_IF_NULL(momentum); + MS_EXCEPTION_IF_NULL(monad_state); auto prim = std::make_shared(kFusedWeightScaleApplyMomentum); MS_EXCEPTION_IF_NULL(prim); - std::vector inputs = {NewValueNode(prim), weight_decay, scale, variable, - accumulation, learning_rate, gradient, momentum}; + std::vector inputs = {NewValueNode(prim), weight_decay, scale, variable, accumulation, + learning_rate, gradient, momentum, monad_state}; auto replace_node = graph->NewCNode(inputs); MS_EXCEPTION_IF_NULL(replace_node); auto types = {AnfAlgo::GetOutputInferDataType(node, 0)}; diff --git a/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_weight_scale_fusion.h b/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_weight_scale_fusion.h index c1b92c8242..e0f8ea555e 100644 --- a/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_weight_scale_fusion.h +++ b/mindspore/ccsrc/backend/optimizer/gpu/apply_momentum_weight_scale_fusion.h @@ -25,6 +25,7 @@ class ApplyMomentumWeightDecayScaleFusion : public PatternProcessPass { public: explicit ApplyMomentumWeightDecayScaleFusion(bool multigraph = true) : PatternProcessPass("momentum_weightdecay_scale_fusion", multigraph) { + monad_ = std::make_shared(); weight_decay_ = std::make_shared(); scale_ = std::make_shared(IsScalar); variable_ = std::make_shared(); @@ -32,6 +33,7 @@ class ApplyMomentumWeightDecayScaleFusion : public PatternProcessPass { learning_rate_ = std::make_shared(); gradient_ = std::make_shared(); momentum_ = std::make_shared(); + monad_state_ = std::make_shared(); } ~ApplyMomentumWeightDecayScaleFusion() override = default; const BaseRef DefinePattern() const override; @@ -40,6 +42,7 @@ class ApplyMomentumWeightDecayScaleFusion : public PatternProcessPass { private: static bool IsScalar(const BaseRef &n); + VarPtr monad_; VarPtr weight_decay_; VarPtr scale_; VarPtr variable_; @@ -47,6 +50,7 @@ class ApplyMomentumWeightDecayScaleFusion : public PatternProcessPass { VarPtr learning_rate_; VarPtr gradient_; VarPtr momentum_; + VarPtr monad_state_; }; } // namespace opt } // namespace mindspore diff --git 
a/mindspore/ccsrc/backend/optimizer/gpu/combine_cast_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/combine_cast_fusion.cc index e08016e9e7..f62d4d05c6 100644 --- a/mindspore/ccsrc/backend/optimizer/gpu/combine_cast_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/gpu/combine_cast_fusion.cc @@ -37,11 +37,13 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const std::vector for (size_t idx = 0; idx < node_list.size(); ++idx) { auto cnode = utils::cast(node_list[idx]); MS_EXCEPTION_IF_NULL(cnode); - for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) { + size_t input_num = AnfAlgo::GetInputTensorNum(cnode); + for (size_t input_index = 0; input_index < input_num; ++input_index) { inputs_device_format.push_back(kOpFormat_DEFAULT); inputs_device_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index)); } - for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(cnode); ++output_index) { + size_t output_num = AnfAlgo::GetOutputTensorNum(cnode); + for (size_t output_index = 0; output_index < output_num; ++output_index) { outputs_device_format.push_back(kOpFormat_DEFAULT); outputs_device_type.push_back(AnfAlgo::GetOutputInferDataType(cnode, output_index)); outputs_shape.push_back(AnfAlgo::GetOutputInferShape(cnode, output_index)); @@ -57,16 +59,39 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const std::vector bool GetDealList(const std::vector &node_list, std::vector> *deal_list) { std::vector cast_32to16_list; std::vector cast_16to32_list; + AnfNodePtr cast_32to16_load_monad = nullptr; + AnfNodePtr cast_16to32_load_monad = nullptr; + constexpr size_t second_input_index = 2; for (auto &cast_node : node_list) { // currently, we only deal with the construct : [Param->Cast->] to avoid being a cycle. 
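// With the monad form, a parameter can also reach Cast through a Load, so both Cast(Param)
// and Cast(Load(Param, U)) are accepted below, and the Loads combined in one direction must
// all carry the same monad input.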
- if (cast_node != nullptr && cast_node->isa() && AnfAlgo::GetCNodeName(cast_node) == "Cast" && - (AnfAlgo::GetInputNode(utils::cast(cast_node), 0))->isa()) { - auto dst = AnfAlgo::GetOutputInferDataType(cast_node, 0); - auto src = AnfAlgo::GetPrevNodeOutputInferDataType(cast_node, 0); - if (dst == kNumberTypeFloat16 && src == kNumberTypeFloat32) { - cast_32to16_list.push_back(cast_node); - } else if (dst == kNumberTypeFloat32 && src == kNumberTypeFloat16) { - cast_16to32_list.push_back(cast_node); + // { prim::kPrimCast, { prim::kPrimLoad, Parameter, U }} + if (IsPrimitiveCNode(cast_node, prim::kPrimCast)) { + auto input0 = AnfAlgo::GetInputNode(utils::cast(cast_node), 0); + if (input0->isa() || (IsPrimitiveCNode(input0, prim::kPrimLoad) && + (AnfAlgo::GetInputNode(utils::cast(input0), 0))->isa())) { + auto dst = AnfAlgo::GetOutputInferDataType(cast_node, 0); + auto src = AnfAlgo::GetPrevNodeOutputInferDataType(cast_node, 0); + if (dst == kNumberTypeFloat16 && src == kNumberTypeFloat32) { + cast_32to16_list.push_back(cast_node); + if (IsPrimitiveCNode(input0, prim::kPrimLoad)) { + auto &monad = input0->cast()->inputs().at(second_input_index); + if (cast_32to16_load_monad == nullptr) { + cast_32to16_load_monad = monad; + } else if (cast_32to16_load_monad != monad) { + return false; + } + } + } else if (dst == kNumberTypeFloat32 && src == kNumberTypeFloat16) { + cast_16to32_list.push_back(cast_node); + if (IsPrimitiveCNode(input0, prim::kPrimLoad)) { + auto &monad = input0->cast()->inputs().at(second_input_index); + if (cast_16to32_load_monad == nullptr) { + cast_16to32_load_monad = monad; + } else if (cast_16to32_load_monad != monad) { + return false; + } + } + } } } } diff --git a/mindspore/ccsrc/backend/optimizer/gpu/combine_momentum_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/combine_momentum_fusion.cc index bd53b96565..9af6962d9e 100644 --- a/mindspore/ccsrc/backend/optimizer/gpu/combine_momentum_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/gpu/combine_momentum_fusion.cc @@ -36,11 +36,13 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const std::vector for (size_t idx = 0; idx < node_list.size(); ++idx) { auto cnode = utils::cast(node_list[idx]); MS_EXCEPTION_IF_NULL(cnode); - for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) { + size_t input_num = AnfAlgo::GetInputTensorNum(cnode); + for (size_t input_index = 0; input_index < input_num; ++input_index) { inputs_device_format.push_back(kOpFormat_DEFAULT); inputs_device_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_index)); } - for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(cnode); ++output_index) { + size_t output_num = AnfAlgo::GetOutputTensorNum(cnode); + for (size_t output_index = 0; output_index < output_num; ++output_index) { outputs_device_format.push_back(kOpFormat_DEFAULT); outputs_device_type.push_back(AnfAlgo::GetOutputInferDataType(cnode, output_index)); outputs_shape.push_back(AnfAlgo::GetOutputInferShape(cnode, output_index)); diff --git a/mindspore/ccsrc/backend/optimizer/gpu/cudnn_inplace_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/cudnn_inplace_fusion.cc index 21157c9224..e2c20febad 100644 --- a/mindspore/ccsrc/backend/optimizer/gpu/cudnn_inplace_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/gpu/cudnn_inplace_fusion.cc @@ -53,6 +53,8 @@ std::set kSkipOpNames = { std::map kAggregatesOpNames = { {kConv2DBackpropInputOpName, 0}, {kmaxPoolGradOpName, 2}, {kFusedBatchNormGradExWithAddAndActivation, 0}}; 
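+// A dedicated cover index is only computed for the pattern of exactly two Conv2DBackpropInput
+// inplace nodes; in every other case the first inplace node keeps the default "cover" role.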
+constexpr size_t inplace_node_size = 2; + template void SetPrimAttr(AnfNodePtr inplace_node, const string &key, const T &value) { auto primitive = AnfAlgo::GetCNodePrimitive(inplace_node); @@ -60,40 +62,103 @@ void SetPrimAttr(AnfNodePtr inplace_node, const string &key, const T &value) { primitive->AddAttr(key, MakeValue(value)); } -void SetNodeAttr(AnfNodeIndex aggregate_node, AnfNodePtr skip_node, std::vector *inplace_node) { +std::pair GetCoverIndex(const std::vector &inplace_node) { + if (inplace_node.size() != inplace_node_size) { + return {0, false}; + } + auto first_node = inplace_node[0].node; + auto second_node = inplace_node[1].node; + if (AnfAlgo::GetCNodeName(first_node) != kConv2DBackpropInputOpName || + AnfAlgo::GetCNodeName(second_node) != kConv2DBackpropInputOpName) { + return {0, false}; + } + + auto first_node_prim = AnfAlgo::GetCNodePrimitive(first_node); + auto first_node_channel = first_node_prim.get()->GetAttr("out_channel"); + MS_EXCEPTION_IF_NULL(first_node_channel); + size_t first_channel = first_node_channel->cast()->value(); + auto second_node_prim = AnfAlgo::GetCNodePrimitive(second_node); + auto second_node_channel = second_node_prim.get()->GetAttr("out_channel"); + MS_EXCEPTION_IF_NULL(second_node_channel); + size_t second_channel = second_node_channel->cast()->value(); + size_t cover_index = (first_channel >= second_channel) ? 0 : 1; + return {cover_index, true}; +} + +void CopyKernelInfo(AnfNodePtr src, AnfNodePtr dst) { + auto build_info = AnfAlgo::GetSelectKernelBuildInfo(src); + AnfAlgo::SetSelectKernelBuildInfo(build_info, dst.get()); + size_t output_num = AnfAlgo::GetOutputTensorNum(src); + std::vector types; + std::vector> shapes; + for (size_t i = 0; i < output_num; i++) { + types.emplace_back(AnfAlgo::GetOutputInferDataType(src, i)); + shapes.emplace_back(AnfAlgo::GetOutputInferShape(src, i)); + } + AnfAlgo::SetOutputInferTypeAndShape(types, shapes, dst.get()); +} + +void CheckInplaceNodeInputs(std::vector *inplace_node, const FuncGraphPtr &graph) { + if (inplace_node->size() == inplace_node_size) { + auto first_cnode = (*inplace_node)[0].node->cast(); + MS_EXCEPTION_IF_NULL(first_cnode); + auto first_node_input = first_cnode->input(1); + auto second_cnode = (*inplace_node)[1].node->cast(); + MS_EXCEPTION_IF_NULL(second_cnode); + auto second_node_input = second_cnode->input(1); + + // If the two inplace nodes share the same input, inserting Depend would create a loop, + // so copy a new input for one of the inplace nodes. + if (first_node_input == second_node_input) { + auto cnode = first_node_input->cast(); + auto new_input = graph->NewCNode(cnode->inputs()); + new_input->set_abstract(first_node_input->abstract()); + CopyKernelInfo(first_node_input, new_input); + auto new_inplace_node = graph->NewCNode({first_cnode->input(0), new_input, first_cnode->input(2)}); + new_inplace_node->set_abstract(first_cnode->abstract()); + CopyKernelInfo(first_cnode, new_inplace_node); + auto manager = graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + manager->Replace(first_cnode, new_inplace_node); + (*inplace_node)[0].node = new_inplace_node; + } + } +} + +void SetNodeAttr(AnfNodeIndex aggregate_node, AnfNodePtr skip_node, std::vector *inplace_node, + const FuncGraphPtr &graph) { SetPrimAttr(aggregate_node.node, "aggregate", true); SetPrimAttr(aggregate_node.node, "aggregate_input_index", aggregate_node.index); SetPrimAttr(skip_node, "skip", true); static uint32_t group = 0; + auto [cover_index, order_required] = GetCoverIndex(*inplace_node); + if (order_required) { +
CheckInplaceNodeInputs(inplace_node, graph); + } for (size_t i = 0; i < inplace_node->size(); i++) { - auto algo = (i == 0) ? "cover" : "accumulation"; - SetPrimAttr((*inplace_node)[i].node, "inplace_algo", algo); - SetPrimAttr((*inplace_node)[i].node, "inplace_group", group); - SetPrimAttr((*inplace_node)[i].node, "inplace_output_index", (*inplace_node)[i].index); + auto algo = (i == cover_index) ? "cover" : "accumulation"; + auto node = (*inplace_node)[i].node; + SetPrimAttr(node, "inplace_algo", algo); + SetPrimAttr(node, "inplace_group", group); + SetPrimAttr(node, "inplace_output_index", (*inplace_node)[i].index); + // For Conv2DBackpropInput, a Depend node needs to be inserted to keep the order; the node with the larger out_channel is set to cover. + if (order_required && i != cover_index) { + auto acc_node = node; + auto cover_node = (*inplace_node)[cover_index].node; + auto acc_node_input = acc_node->cast()->input(1); + std::vector inputs = {NewValueNode(std::make_shared(prim::kPrimDepend->name())), + acc_node_input, cover_node}; + auto depend_node = graph->NewCNode(inputs); + depend_node->set_abstract(acc_node_input->abstract()); + auto manager = graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + manager->Replace(acc_node_input, depend_node); + } } group++; } -void InsertControlDependToGraph(const FuncGraphPtr &graph, const std::vector &inplace_nodes, - const AnfNodePtr aggregate_node) { - std::vector inputs1 = {NewValueNode(std::make_shared(prim::kPrimControlDepend->name())), - inplace_nodes[0].node, inplace_nodes[1].node}; - auto control_depend_node = graph->NewCNode(inputs1); - - std::vector inputs2 = {NewValueNode(std::make_shared(prim::kPrimDepend->name())), - aggregate_node, control_depend_node}; - auto depend_node = graph->NewCNode(inputs2); - - auto users = GetRealNodeUsedList(graph, aggregate_node); - if (users->size() == 0) { - MS_LOG(EXCEPTION) << "No users found: " << aggregate_node->DebugString(); - } - auto mount_node = users->at(0).first->cast(); - MS_EXCEPTION_IF_NULL(mount_node); - mount_node->set_input(kFirstDataInputIndex, depend_node); -} - bool PatternMatch(const FuncGraphPtr &graph, const AnfNodePtr &node, AnfNodeIndex *aggregate, AnfNodePtr *skip_node, std::vector *inplace) { MS_EXCEPTION_IF_NULL(skip_node); @@ -117,7 +182,8 @@ bool PatternMatch(const FuncGraphPtr &graph, const AnfNodePtr &node, AnfNodeInde auto cnode = (*skip_node)->cast(); MS_EXCEPTION_IF_NULL(cnode); - for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(cnode); i++) { + size_t input_num = AnfAlgo::GetInputTensorNum(cnode); + for (size_t i = 0; i < input_num; i++) { auto inplace_node = AnfAlgo::GetInputNode(utils::cast(*skip_node), i); if (!inplace_node->isa()) { return false; @@ -187,9 +253,7 @@ bool CudnnInplaceAggregate::Run(const FuncGraphPtr &graph) { << "; inplace node 1: " << inplace_node[1].index << ", " << inplace_node[1].node->DebugString() << std::endl; // 2. Set Node attr - SetNodeAttr(aggregate_node, skip_node, &inplace_node); - // 3.
Set dependence for inplace nodes - InsertControlDependToGraph(graph, inplace_node, aggregate_node.node); + SetNodeAttr(aggregate_node, skip_node, &inplace_node, graph); } return true; diff --git a/mindspore/ccsrc/backend/optimizer/gpu/post_batch_norm_add_relu_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/post_batch_norm_add_relu_fusion.cc new file mode 100644 index 0000000000..516bf2c8d2 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/gpu/post_batch_norm_add_relu_fusion.cc @@ -0,0 +1,97 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/gpu/post_batch_norm_add_relu_fusion.h" + +#include +#include +#include + +#include "backend/session/anf_runtime_algorithm.h" +#include "ir/primitive.h" +#include "utils/utils.h" +#include "backend/optimizer/common/helper.h" +#include "runtime/device/gpu/kernel_info_setter.h" + +namespace mindspore { +namespace opt { +const BaseRef PostBatchNormAddReluFusion::DefinePattern() const { + VectorRef batch_norm_ex = VectorRef({prim::kPrimFusedBatchNormEx, x_, scale_, bias_, mean_, var_}); + VectorRef tuple_get_item = VectorRef({prim::kPrimTupleGetItem, batch_norm_ex, index_}); + VectorRef tensor_add = VectorRef({prim::kPrimAdd, z_, tuple_get_item}); + VectorRef relu = VectorRef({prim::kPrimRelu, tensor_add}); + return relu; +} + +const AnfNodePtr PostBatchNormAddReluFusion::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, + const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(node); + + auto tensor_add = AnfAlgo::GetInputNode(utils::cast(node), 0); + MS_EXCEPTION_IF_NULL(tensor_add); + auto tuple_get_item = AnfAlgo::GetInputNode(utils::cast(tensor_add), 1); + MS_EXCEPTION_IF_NULL(tuple_get_item); + auto batch_norm_ex = AnfAlgo::GetInputNode(utils::cast(tuple_get_item), 0); + MS_EXCEPTION_IF_NULL(batch_norm_ex); + auto format_attr = AnfAlgo::GetCNodePrimitive(batch_norm_ex)->GetAttr("format"); + MS_EXCEPTION_IF_NULL(format_attr); + auto format = GetValue(format_attr); + if (AnfAlgo::GetInputFormat(batch_norm_ex, 0) != kOpFormat_NHWC && format != "NHWC") { + return nullptr; + } + auto shape = AnfAlgo::GetInputDeviceShape(batch_norm_ex, 0); + if (shape.back() % kBNChannelMultipleFactor != 0) { + return nullptr; + } + + auto x = AnfAlgo::GetInputNode(utils::cast(batch_norm_ex), 0); + auto scale = AnfAlgo::GetInputNode(utils::cast(batch_norm_ex), 1); + auto bias = AnfAlgo::GetInputNode(utils::cast(batch_norm_ex), 2); + auto mean = AnfAlgo::GetInputNode(utils::cast(batch_norm_ex), 3); + auto var = AnfAlgo::GetInputNode(utils::cast(batch_norm_ex), 4); + auto z = AnfAlgo::GetInputNode(utils::cast(tensor_add), 0); + + MS_EXCEPTION_IF_NULL(x); + MS_EXCEPTION_IF_NULL(scale); + MS_EXCEPTION_IF_NULL(bias); + MS_EXCEPTION_IF_NULL(mean); + MS_EXCEPTION_IF_NULL(var); + MS_EXCEPTION_IF_NULL(z); + + auto prim = std::make_shared(kFusedBatchNormExWithAddAndActivation); + MS_EXCEPTION_IF_NULL(prim); + std::vector inputs = 
{NewValueNode(prim), x, scale, bias, mean, var, z}; + auto fused_batch_norm_with_add_relu = graph->NewCNode(inputs); + MS_EXCEPTION_IF_NULL(fused_batch_norm_with_add_relu); + + std::vector outputs_type; + std::vector> outputs_shape; + auto output_num = AnfAlgo::GetOutputTensorNum(batch_norm_ex); + for (size_t i = 0; i < output_num; i++) { + outputs_type.push_back(AnfAlgo::GetOutputInferDataType(batch_norm_ex, i)); + outputs_shape.push_back(AnfAlgo::GetOutputInferShape(batch_norm_ex, i)); + } + AnfAlgo::SetOutputInferTypeAndShape(outputs_type, outputs_shape, fused_batch_norm_with_add_relu.get()); + AnfAlgo::CopyNodeAttrs(batch_norm_ex, fused_batch_norm_with_add_relu); + + auto manager = graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + manager->Replace(batch_norm_ex, fused_batch_norm_with_add_relu); + device::gpu::SetKernelInfo(fused_batch_norm_with_add_relu); + return tuple_get_item; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/gpu/post_batch_norm_add_relu_fusion.h b/mindspore/ccsrc/backend/optimizer/gpu/post_batch_norm_add_relu_fusion.h new file mode 100644 index 0000000000..8d2e4ecd32 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/gpu/post_batch_norm_add_relu_fusion.h @@ -0,0 +1,51 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_POST_BATCH_NORM_ADD_RELU_FUSION_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_POST_BATCH_NORM_ADD_RELU_FUSION_H_ + +#include +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class PostBatchNormAddReluFusion : public PatternProcessPass { + public: + explicit PostBatchNormAddReluFusion(bool multigraph = true) + : PatternProcessPass("post_batch_norm_add_relu_fusion", multigraph) { + x_ = std::make_shared(); + scale_ = std::make_shared(); + bias_ = std::make_shared(); + mean_ = std::make_shared(); + var_ = std::make_shared(); + index_ = std::make_shared(); + z_ = std::make_shared(); + } + ~PostBatchNormAddReluFusion() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + VarPtr x_; + VarPtr scale_; + VarPtr bias_; + VarPtr mean_; + VarPtr var_; + VarPtr index_; + VarPtr z_; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GPU_POST_BATCH_NORM_ADD_RELU_FUSION_H_ diff --git a/mindspore/ccsrc/backend/optimizer/gpu/relu_v2_pass.cc b/mindspore/ccsrc/backend/optimizer/gpu/relu_v2_pass.cc index c0094d827a..7c8160c144 100644 --- a/mindspore/ccsrc/backend/optimizer/gpu/relu_v2_pass.cc +++ b/mindspore/ccsrc/backend/optimizer/gpu/relu_v2_pass.cc @@ -32,9 +32,7 @@ const size_t kReluV2OutputNum = 2; CNodePtr GetRelu(const CNodePtr &relu_grad) { MS_EXCEPTION_IF_NULL(relu_grad); - if (relu_grad->size() != kReluGradInputNum) { - MS_LOG_EXCEPTION << "ReluGrad has wrong input size " << relu_grad->size(); - } + CheckCNodeInputSize(relu_grad, kReluGradInputTensorNum); auto relu_anf = relu_grad->input(2); MS_EXCEPTION_IF_NULL(relu_anf); return relu_anf->cast(); @@ -47,11 +45,13 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) { std::vector outputs_type; kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; - for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(node); ++input_index) { + size_t input_num = AnfAlgo::GetInputTensorNum(node); + for (size_t input_index = 0; input_index < input_num; ++input_index) { inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index)); inputs_format.push_back(kOpFormat_DEFAULT); } - for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(node); ++output_index) { + size_t output_num = AnfAlgo::GetOutputTensorNum(node); + for (size_t output_index = 0; output_index < output_num; ++output_index) { outputs_type.push_back(AnfAlgo::GetOutputInferDataType(node, output_index)); outputs_format.push_back(kOpFormat_DEFAULT); } @@ -65,9 +65,7 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(CNodePtr node) { CNodePtr CreateReluV2(const FuncGraphPtr &graph, const CNodePtr &relu) { MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(relu); - if (relu->size() != kReluInputNum) { - MS_LOG_EXCEPTION << "Relu has wrong input size " << relu->size(); - } + CheckCNodeInputSize(relu, kReluInputTensorNum); auto prim = std::make_shared(kReluV2OpName); std::vector inputs = {NewValueNode(prim), relu->input(1)}; @@ -106,7 +104,8 @@ CNodePtr CreateReluGradV2(const FuncGraphPtr &graph, const CNodePtr &relu_grad, std::vector types; std::vector> shapes; - for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(relu_grad); i++) { + size_t output_num = AnfAlgo::GetOutputTensorNum(relu_grad); + for (size_t i = 0; i < output_num; i++) { 
types.push_back(AnfAlgo::GetOutputInferDataType(relu_grad, i)); shapes.push_back(AnfAlgo::GetOutputInferShape(relu_grad, i)); } diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/add_atomic_clean_gpu.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/add_atomic_clean_gpu.cc index d81a4120e1..2ca56fd434 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/add_atomic_clean_gpu.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/add_atomic_clean_gpu.cc @@ -305,52 +305,14 @@ void AtomicCleanInsertter::AddDepend(const FuncGraphPtr &main_graph, const AnfNo user_cnode->set_input(index, depend_cnode); } -AnfNodePtr AtomicCleanInsertter::AddControlDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &prior_node, - const AnfNodePtr &behind_node, const AnfNodePtr &patron_node) { - // Create control depend, first input is composite op, second is user - AnfNodePtrList cd_inputs = {NewValueNode(prim::kPrimControlDepend), prior_node, behind_node}; - auto control_depend_cnode = main_graph->NewCNode(cd_inputs); - main_graph->AddNode(control_depend_cnode); - - // Create depend node to hold new control depend node. - AnfNodePtrList d_inputs = {NewValueNode(prim::kPrimDepend), patron_node, control_depend_cnode}; - auto depend_cnode = main_graph->NewCNode(d_inputs); - depend_cnode->set_abstract(patron_node->abstract()); - main_graph->AddNode(depend_cnode); - - return depend_cnode; -} - -std::tuple AtomicCleanInsertter::FindPatronNode(const KernelGraphPtr &main_graph) { - auto mng = main_graph->manager(); - if (mng == nullptr) { - mng = Manage(main_graph, true); - main_graph->set_manager(mng); - } - - AnfNodePtr patron_node; - - auto return_cnode = main_graph->get_return()->cast(); - MS_EXCEPTION_IF_NULL(return_cnode); - auto output_node = return_cnode->input(kFirstDataInputIndex); - if (IsPrimitiveCNode(output_node, prim::kPrimMakeTuple)) { - auto output_cnode = output_node->cast(); - MS_EXCEPTION_IF_NULL(output_cnode); - patron_node = output_cnode->input(kFirstDataInputIndex); - } else { - patron_node = output_node; - } - - auto &user_nodes = mng->node_users()[patron_node]; - auto user = user_nodes.begin(); - return std::make_tuple(patron_node, user->first, user->second); -} - -void AtomicCleanInsertter::PostprocessForLastPatron(const AnfNodePtr &patron_node, const AnfNodePtr &patron_user, - int index) { - auto patron_user_cnode = patron_user->cast(); - MS_EXCEPTION_IF_NULL(patron_user_cnode); - patron_user_cnode->set_input(index, patron_node); +CNodePtr AtomicCleanInsertter::InsertUpdateState(const KernelGraphPtr &main_graph, const CNodePtr &composite_node) { + // Insert update_state_node; it needs a monad node as its first input. + auto u = NewValueNode(kUMonad); + u->set_abstract(kUMonad->ToAbstract()); + AnfNodePtrList update_state_inputs = {NewValueNode(prim::kPrimUpdateState), u, composite_node}; + auto update_state_cnode = main_graph->NewCNode(update_state_inputs); + main_graph->AddNode(update_state_cnode); + return update_state_cnode; } CNodePtr AtomicCleanInsertter::CreateAtomicCleanCompositeNode(const KernelGraphPtr &main_graph, TypeId dst_type) { @@ -474,24 +436,21 @@ std::vector > AtomicCleanInsertter::FindOriginCNodeUs } void AtomicCleanInsertter::ProcessOriginCNodeUser(const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node, - const AnfNodePtr &broadcast_to_node, const FuncGraphManagerPtr &mng) { + const AnfNodePtr &broadcast_to_node, + const AnfNodePtr &update_state_node, const FuncGraphManagerPtr &mng) { // 1. find users, change getitem index if needed.
std::vector > reduce_user_nodes = FindOriginCNodeUsers(main_graph, composite_node, mng, true); for (const auto &[user_node, index] : reduce_user_nodes) { - // 2. set ac output as user's input. - // 3. Make sure modified composite node running first. - // * To not change the origin node's dependency relation, add ControlDepend and Depend node. - // * For Return node and output node, ControlDepend node will change the order of these two node, which will may - // main graph running failed. So only add Depend node to meet the need of execute order. - if (IsPrimitiveCNode(user_node, prim::kPrimReturn) || user_node == main_graph->output()) { - AddDepend(main_graph, broadcast_to_node, composite_node, user_node, index); - } else { - auto user_cnode = user_node->cast(); - MS_EXCEPTION_IF_NULL(user_cnode); - user_cnode->set_input(index, broadcast_to_node); - to_process_order_.emplace_back(composite_node, user_node); - } + // 2. Make sure the modified composite node runs first: create load_node first, then add edges connecting + // update_state_node, broadcast_to_node and load_node to keep the order. + AnfNodePtrList load_inputs = {NewValueNode(prim::kPrimLoad), broadcast_to_node, update_state_node}; + auto load_node = main_graph->NewCNode(load_inputs); + main_graph->AddNode(load_node); + auto user_cnode = user_node->cast(); + MS_EXCEPTION_IF_NULL(user_cnode); + user_cnode->set_input(index, load_node); + to_process_order_.emplace_back(composite_node, user_node); } } @@ -509,8 +468,11 @@ void AtomicCleanInsertter::InsertAtomicClean(const KernelGraphPtr &main_graph, c // Note: if it's single output, this will increase total memory because of a fake out. ProcessOriginCNode(origin_composite_node, broadcast_to_node, mng); - // Replace origin ReduceSum's user with atomic clean output, and add control depend from composite op to user. - ProcessOriginCNodeUser(main_graph, origin_composite_node, broadcast_to_node, mng); + // Insert update_state_node to keep execution order.
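+ // UpdateState(U, composite) mounts the composite node on the monad chain; the Load created in
+ // ProcessOriginCNodeUser then reads the cleaned output through that state, so every user is
+ // scheduled after the atomic-clean composite node.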
+ auto update_state_node = InsertUpdateState(main_graph, origin_composite_node); + + // Replace origin ReduceSum's user with atomic clean output + ProcessOriginCNodeUser(main_graph, origin_composite_node, broadcast_to_node, update_state_node, mng); MS_LOG(INFO) << "Target node: " << origin_composite_node->fullname_with_scope() << ", clean node: " << broadcast_to_node->fullname_with_scope(); } @@ -554,14 +516,6 @@ bool AtomicCleanInsertter::Run(const FuncGraphPtr &func_graph) { } if (changed) { - if (!to_process_order_.empty()) { - auto [patron_node, patron_user, user_index] = FindPatronNode(kernel_graph); - for (const auto &[prior, behind] : to_process_order_) { - patron_node = AddControlDepend(kernel_graph, prior, behind, patron_node); - } - PostprocessForLastPatron(patron_node, patron_user, user_index); - } - mng->RemoveRoots(); mng->KeepRoots({func_graph}); } diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/add_atomic_clean_gpu.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/add_atomic_clean_gpu.h index e6dffde93e..be137919e8 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/add_atomic_clean_gpu.h +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/add_atomic_clean_gpu.h @@ -37,9 +37,10 @@ class AtomicCleanInsertter : public Pass { virtual void CorrectKernelBuildInfo(const AnfNodePtr &composite_node, const AnfNodePtr &new_input); virtual void ProcessOriginCNode(const AnfNodePtr &composite_node, const AnfNodePtr &new_input, const FuncGraphManagerPtr &mng); - void InsertAtomicClean(const KernelGraphPtr &main_graph, const AnfNodePtr &anf_node, const FuncGraphManagerPtr &mng); void AddDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &clean_node, const AnfNodePtr &composite_node, const AnfNodePtr &user_node, int index); + void InsertAtomicClean(const KernelGraphPtr &main_graph, const AnfNodePtr &anf_node, const FuncGraphManagerPtr &mng); + CNodePtr InsertUpdateState(const KernelGraphPtr &main_graph, const CNodePtr &composite_node); CNodePtr atomic_add_node_{nullptr}; private: @@ -48,11 +49,8 @@ class AtomicCleanInsertter : public Pass { CNodePtr CreateAtomicCleanCompositeNode(const KernelGraphPtr &main_graph, TypeId dst_type); void CreateInplaceAssignNodeAndCorrectReturn(const FuncGraphPtr &sub_graph, const AnfNodePtr &new_parameter); void ProcessOriginCNodeUser(const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node, - const AnfNodePtr &broadcast_to_node, const FuncGraphManagerPtr &mng); - std::tuple FindPatronNode(const KernelGraphPtr &main_graph); - AnfNodePtr AddControlDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &prior_node, - const AnfNodePtr &behind_node, const AnfNodePtr &patron_node); - void PostprocessForLastPatron(const AnfNodePtr &patron_node, const AnfNodePtr &patron_user, int index); + const AnfNodePtr &broadcast_to_node, const AnfNodePtr &update_state_node, + const FuncGraphManagerPtr &mng); std::vector> FindOriginCNodeUsers(const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node, const FuncGraphManagerPtr &mng, bool correct_index); diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/basic_ops_fusion.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/basic_ops_fusion.cc index f1057685d6..e8e2295f6e 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/basic_ops_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/basic_ops_fusion.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache 
License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -149,9 +149,18 @@ bool FuseBasicOps(const FuncGraphPtr &kernel_graph, const std::vectorinsert(fuse_nodes.begin(), fuse_nodes.end()); AnfNodePtr fused_new_node; diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/depend_formater.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/depend_formater.cc index ef3429e451..10a778ae5a 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/depend_formater.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/depend_formater.cc @@ -109,7 +109,10 @@ bool DependFormater::Run(const FuncGraphPtr &func_graph) { // 1. Try to remove redundant depend. bool changed = false; auto nodes = TopoSort(func_graph->get_return()); - std::for_each(nodes.rbegin(), nodes.rend(), [&changed, &mng](const AnfNodePtr &node) { + std::for_each(nodes.rbegin(), nodes.rend(), [&changed, &mng](const AnfNodePtr &node) -> void { + if (HasAbstractMonad(node)) { + return; + } if (RemoveRedundantDepend(node, mng)) { changed = true; } @@ -126,7 +129,8 @@ bool DependFormater::Run(const FuncGraphPtr &func_graph) { // Find depend and its free nodes. for (const auto &node : nodes) { - if (!IsPrimitiveCNode(node, prim::kPrimDepend)) { + if (!IsPrimitiveCNode(node, prim::kPrimDepend) || + HasAbstractMonad(node->cast()->input(kDependAttachNodeIndex))) { continue; } diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc index 26e1237dd9..3a6b949f80 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc @@ -177,6 +177,7 @@ bool GraphKernelExpander::Run(const FuncGraphPtr &func_graph) { std::shared_ptr pass = std::make_shared(); pass->Run(func_graph); } + auto mng = func_graph->manager(); if (mng == nullptr) { mng = Manage(func_graph, true); diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc index ba90ef88e7..f2fa804431 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc @@ -494,8 +494,8 @@ std::vector GetFusibleOpList() { prim::kPrimRealDiv, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog, prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN, prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData, prim::kPrimSelect, prim::kPrimGreater, - prim::kPrimAssign, prim::kPrimReduceSum, prim::kPrimTanh, prim::kPrimReshape, prim::kPrimTranspose, - prim::kPrimCast, prim::kPrimExpandDims}; + prim::kPrimCast, prim::kPrimReduceSum, prim::kPrimTanh, prim::kPrimReshape, prim::kPrimTranspose, + prim::kPrimAssign, prim::kPrimExpandDims}; #else std::vector fusible_basic_ops; #endif diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_splitter.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_splitter.cc index c9760955ea..070cf17b4a 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_splitter.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_splitter.cc @@ -629,7 +629,7 @@ class CostModelSplitSchemer : public Splitter::SplitSchemer { } GetValidKernelNodes(); // call CostModel to get a split plan. 
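// An empty split plan is treated the same way as a cost-model failure.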
- if (!SplitByCostModel() || split_plan_.size() != need_inline_.size()) { + if (!SplitByCostModel() || split_plan_.size() != need_inline_.size() || split_plan_.empty()) { split_plan_.clear(); need_inline_.clear(); return; diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/optimize_assign.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/optimize_assign.cc index 0789ad4c39..875bf6701c 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/optimize_assign.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/optimize_assign.cc @@ -103,28 +103,23 @@ bool HasPathToParamUser(const AnfNodePtr &gk_node, const AnfNodePtr &param_user) return result; } -AnfNodePtr AddControlDepend(const FuncGraphPtr &func_graph, const AnfNodePtr &getitem, const AnfNodePtr &param_user) { - auto mng = func_graph->manager(); - MS_EXCEPTION_IF_NULL(mng); - AnfNodePtrList cd_inputs = {NewValueNode(prim::kPrimControlDepend), getitem, param_user}; - auto cd_node = func_graph->NewCNode(cd_inputs); - func_graph->AddNode(cd_node); - return cd_node; -} - -void LinkControlDepends(const FuncGraphPtr &func_graph, const AnfNodePtrList &cd_nodes) { - auto mng = func_graph->manager(); - MS_EXCEPTION_IF_NULL(mng); - auto output_tuple = func_graph->output()->cast(); - MS_EXCEPTION_IF_NULL(output_tuple); - auto cur_node = output_tuple->input(1); - for (const auto &cd : cd_nodes) { - AnfNodePtrList depend_inputs = {NewValueNode(prim::kPrimDepend), cur_node, cd}; - auto depend_node = func_graph->NewCNode(depend_inputs); - depend_node->set_abstract(depend_inputs[1]->abstract()); - cur_node = depend_node; - } - mng->Replace(output_tuple->input(1), cur_node); +void KeepExecOrder(const FuncGraphPtr &func_graph, const AnfNodePtr &gk_node, const AnfNodePtr &par_user_node, + const FuncGraphManagerPtr &mng) { + // Insert update_state_node; a monad node needs to be mounted.
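The KeepExecOrder hunk here replaces the old ControlDepend bookkeeping with the monad idiom: an UpdateState observes the graph-kernel node, and a Load reads the user's value through that state, so the scheduler must keep the producer before the consumer. A minimal IR-level sketch of the intent, with hypothetical node names (the actual node construction follows just below):

```cpp
// Before (ControlDepend era):
//   %cd  = ControlDepend(%getitem, %user)    // ordering kept in a side structure
//
// After (monad era), as KeepExecOrder builds it:
//   %u   = U                                 // constant UMonad value
//   %st  = UpdateState(%u, %gk_node)         // state now depends on %gk_node
//   %ld  = Load(%par_user_node, %st)         // the read is ordered after %st
```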
+ auto u = NewValueNode(kUMonad); + u->set_abstract(kUMonad->ToAbstract()); + AnfNodePtrList update_state_inputs = {NewValueNode(prim::kPrimUpdateState), u, gk_node}; + auto update_state_node = func_graph->NewCNode(update_state_inputs); + update_state_node->set_abstract(gk_node->abstract()); + func_graph->AddNode(update_state_node); + + // Insert load_node + AnfNodePtrList load_inputs = {NewValueNode(prim::kPrimLoad), par_user_node, update_state_node}; + auto load_node = func_graph->NewCNode(load_inputs); + load_node->set_abstract(par_user_node->abstract()); + func_graph->AddNode(load_node); + + mng->Replace(gk_node, par_user_node); } int64_t GetitemIndex(const AnfNodePtr &getitem) { @@ -133,11 +128,10 @@ int64_t GetitemIndex(const AnfNodePtr &getitem) { return GetValue(value_ptr); } -AnfNodePtrList UpdateUsersOfGraphKernel(const FuncGraphPtr &func_graph, const AnfNodePtr &cnode, - const AnfNodePtr &assign_to, int64_t removed_index) { +void UpdateUsersOfGraphKernel(const FuncGraphPtr &func_graph, const AnfNodePtr &cnode, const AnfNodePtr &assign_to, + int64_t removed_index) { auto mng = func_graph->manager(); MS_EXCEPTION_IF_NULL(mng); - AnfNodePtrList cd_nodes; for (const auto &getitem_iter : mng->node_users()[cnode]) { auto getitem = getitem_iter.first; if (GetitemIndex(getitem) != removed_index) continue; @@ -152,13 +146,10 @@ AnfNodePtrList UpdateUsersOfGraphKernel(const FuncGraphPtr &func_graph, const An if (!AnfAlgo::IsRealKernel(getitem_user) || HasPathToParamUser(cnode, getitem_user)) { continue; } - // keep execution order: cnode -> getitem_user - auto cd_node = AddControlDepend(func_graph, getitem, getitem_user); - cd_nodes.push_back(cd_node); + KeepExecOrder(func_graph, cnode, getitem_user, mng); } break; } - return cd_nodes; } bool RepalceOutputByParameter(const FuncGraphPtr &func_graph) { @@ -166,7 +157,6 @@ bool RepalceOutputByParameter(const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(func_graph); bool changed = false; - AnfNodePtrList control_depend_nodes; for (const auto &n : todos) { if (!AnfAlgo::IsGraphKernel(n)) continue; auto cnode = n->cast(); @@ -174,11 +164,9 @@ bool RepalceOutputByParameter(const FuncGraphPtr &func_graph) { if (replaceable_nodes.empty()) continue; changed = true; for (const auto &iter : replaceable_nodes) { - auto cd_nodes = UpdateUsersOfGraphKernel(func_graph, cnode, iter.second, iter.first); - control_depend_nodes.insert(control_depend_nodes.end(), cd_nodes.begin(), cd_nodes.end()); + UpdateUsersOfGraphKernel(func_graph, cnode, iter.second, iter.first); } } - LinkControlDepends(func_graph, control_depend_nodes); return changed; } diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/parallel_fusion.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/parallel_fusion.cc index 150927f94f..bcd94f1c78 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/parallel_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/parallel_fusion.cc @@ -97,7 +97,8 @@ void ProcessThroughPassCNode(std::function pass_fn, void ProcessDependCNode(OrderedMap *node_rels) { for (auto &[node, node_rel] : (*node_rels)) { - if (!IsPrimitiveCNode(node, prim::kPrimDepend)) { + if (!IsPrimitiveCNode(node, prim::kPrimDepend) || + HasAbstractMonad(node->cast()->input(kDependAttachNodeIndex))) { continue; } @@ -118,96 +119,6 @@ void ProcessDependCNode(OrderedMap *node_rels) { ProcessThroughPassCNode([](const AnfNodePtr &node) { return IsOneOf(node, {prim::kPrimDepend}); }, node_rels); } -std::tuple, std::pair> FindRelationOfControlDepend( - const 
AnfNodePtr &node, OrderedMap *node_rels) { - auto cnode = node->cast(); - auto prior_node = cnode->input(kControlDependPriorIndex); - auto behind_node = cnode->input(kControlDependBehindIndex); - MS_EXCEPTION_IF_NULL(prior_node); - MS_EXCEPTION_IF_NULL(behind_node); - - OrderedSet prior_nodes; - prior_nodes.insert(prior_node); - OrderedSet behind_nodes; - behind_nodes.insert(behind_node); - - int64_t depend_mode = 0; - if (AnfAlgo::HasNodeAttr(kControlDependMode, cnode)) { - depend_mode = AnfAlgo::GetNodeAttr(cnode, kControlDependMode); - } - if (prior_node->isa() && depend_mode == 1) { - prior_nodes = (*node_rels)[prior_node].nexts; - } - if (behind_node->isa()) { - behind_nodes = depend_mode == 1 ? (*node_rels)[behind_node].nexts : OrderedSet(); - } - - // Get real nodes. - AnfNodePtrList real_prior_nodes; - std::set prior_visited; - for (const auto &tmp : prior_nodes) { - AnfAlgo::GetAllFatherRealNode(tmp, &real_prior_nodes, &prior_visited); - } - AnfNodePtrList real_behind_nodes; - std::set behind_visited; - for (const auto &tmp : behind_nodes) { - AnfAlgo::GetAllFatherRealNode(tmp, &real_behind_nodes, &behind_visited); - } - - return std::make_tuple(std::make_pair(prior_node, behind_node), std::make_pair(real_prior_nodes, real_behind_nodes)); -} - -void ReLinkNodesOfControlDependByRelation(const std::unordered_map &control_depend_info, - OrderedMap *node_rels) { - // Relink and its log. - for (const auto &m : control_depend_info) { - const auto &prior = m.second[0]; - const auto &behind = m.second[1]; - (*node_rels)[prior].nexts.insert(behind); - (*node_rels)[behind].pres.insert(prior); - MS_LOG(DEBUG) << "Relink relation of " << m.first->fullname_with_scope() << ": " << prior->fullname_with_scope() - << " -> " << behind->fullname_with_scope(); - } -} - -void ProcessControlDependCNode(OrderedMap *node_rels) { - std::unordered_map control_depend_info; - AnfNodePtrList latter_to_be_erased; - - // Collect ControlDepend node and its input and output nodes. - for (auto &[node, node_rel] : (*node_rels)) { - if (!IsPrimitiveCNode(node, prim::kPrimControlDepend)) { - continue; - } - - auto [direct_relation, real_relations] = FindRelationOfControlDepend(node, node_rels); - auto &[prior_node, behind_node] = direct_relation; - auto &[real_prior_nodes, real_behind_nodes] = real_relations; - - (*node_rels)[prior_node].nexts.erase(node); - (*node_rels)[behind_node].nexts.erase(node); - node_rel.pres.erase(prior_node); - node_rel.pres.erase(behind_node); - - for (auto &first_node : real_prior_nodes) { - for (auto &second_node : real_behind_nodes) { - MS_EXCEPTION_IF_NULL(first_node); - MS_EXCEPTION_IF_NULL(second_node); - control_depend_info.insert({node, {first_node, second_node}}); - } - } - latter_to_be_erased.push_back(node); - } - - // Delete ControlDepend node before relink its relation. - for (const auto &node : latter_to_be_erased) { - node_rels->erase(node); - } - - // Rebuild relation between prior and behind node. 
- ReLinkNodesOfControlDependByRelation(control_depend_info, node_rels); -} - void ProcessTailMakeTupleCNode(OrderedMap *node_rels) { AnfNodePtrList latter_to_be_erased; for (auto &[node, node_rel] : (*node_rels)) { @@ -538,7 +449,6 @@ OrderedMap ParallelOpFusion::GenAnalysisGraph(const An } ProcessDependCNode(&node_rels); - ProcessControlDependCNode(&node_rels); ProcessThroughPassCNode( [](const AnfNodePtr &node) { return IsOneOf(node, {prim::kPrimReshape, prim::kPrimExpandDims, prim::kPrimSqueeze, prim::kPrimTupleGetItem}); diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/split_assign.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/split_assign.cc new file mode 100644 index 0000000000..36dade7b94 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/split_assign.cc @@ -0,0 +1,59 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/optimizer/graph_kernel/split_assign.h" + +#include +#include +#include +#include + +#include "base/core_ops.h" +#include "utils/utils.h" +#include "runtime/device/kernel_info.h" +#include "backend/optimizer/common/helper.h" + +namespace mindspore { +namespace opt { + +const BaseRef SplitAssign::DefinePattern() const { + VarPtr Xs = std::make_shared(); + VarPtr Us = std::make_shared(); + VarPtr UMonad = std::make_shared(); + return VectorRef({prim::kPrimAssign, Xs, Us, UMonad}); +} + +const AnfNodePtr SplitAssign::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const { + MS_EXCEPTION_IF_NULL(node); + CNodePtr cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + CheckCNodeInputSize(cnode, kAssignInputTensorNum); + // Get original assign op's abstract and inputs + AbstractBasePtr original_abstract = cnode->abstract()->Clone(); + auto original_inputs = cnode->inputs(); + // Create depend node + AnfNodePtrList depend_inputs = {NewValueNode(prim::kPrimDepend), original_inputs[1], original_inputs[3]}; + auto depend_cnode = func_graph->NewCNode(depend_inputs); + depend_cnode->set_abstract(original_inputs[1]->abstract()); + depend_cnode->set_kernel_info(std::make_shared()); + // Create new assign node, delete U from inputs. 
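Stepping back from the Process body for a moment: the whole SplitAssign rewrite is easiest to read at the IR level. The U-monad input is peeled off the three-input Assign and re-attached through a Depend on the target, so later graph-kernel passes see a plain two-input Assign while the ordering the monad expressed is preserved. A sketch with hypothetical names:

```cpp
// Pattern matched by SplitAssign::DefinePattern:
//   %out = Assign(%ref, %value, %U)
//
// Rewritten form produced by Process (built right below):
//   %dep = Depend(%ref, %U)         // carries the ordering of the monad
//   %out = Assign(%dep, %value)     // monad input removed from Assign itself
```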
+ AnfNodePtrList new_assign_inputs = {NewValueNode(prim::kPrimAssign), depend_cnode, original_inputs[2]}; + auto new_assign_cnode = func_graph->NewCNode(new_assign_inputs); + new_assign_cnode->set_abstract(original_abstract); + new_assign_cnode->set_kernel_info(cnode->kernel_info_ptr()); + return new_assign_cnode; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/split_assign.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/split_assign.h new file mode 100644 index 0000000000..176f0e5b01 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/split_assign.h @@ -0,0 +1,32 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SPLIT_ASSIGN_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SPLIT_ASSIGN_H_ + +#include "backend/optimizer/common/optimizer.h" + +namespace mindspore { +namespace opt { +class SplitAssign : public PatternProcessPass { + public: + explicit SplitAssign(bool multigraph = true) : PatternProcessPass("split_assign", multigraph) {} + ~SplitAssign() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SPLIT_ASSIGN_H_ diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/substitute_dropout.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/substitute_dropout.cc index d7164b9034..d5d8d31bfd 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/substitute_dropout.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/substitute_dropout.cc @@ -41,13 +41,15 @@ const BaseRef SubstituteDropout::DefinePattern() const { void SetNewKernelInfo(const CNodePtr &kernel_node) { std::vector inputs_format; std::vector inputs_type; - for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + for (size_t input_index = 0; input_index < input_num; ++input_index) { inputs_format.emplace_back(AnfAlgo::GetPrevNodeOutputFormat(kernel_node, input_index)); inputs_type.push_back(AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, input_index)); } std::vector outputs_format; std::vector outputs_type; - for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) { + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + for (size_t output_index = 0; output_index < output_num; ++output_index) { outputs_format.emplace_back(AnfAlgo::GetPrevNodeOutputFormat(kernel_node, output_index)); outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index)); } @@ -69,15 +71,13 @@ const AnfNodePtr SubstituteDropout::Process(const FuncGraphPtr &func_graph, cons MS_EXCEPTION_IF_NULL(node); CNodePtr cnode = 
node->cast(); MS_EXCEPTION_IF_NULL(cnode); - if (cnode->inputs().size() < kDropoutInputNum) { - MS_LOG(EXCEPTION) << "Dropout's input num is wrong"; - } + CheckCNodeInputSize(cnode, kDropoutInputTensorNum); AbstractBasePtr old_abstract = cnode->abstract()->Clone(); auto shape = AnfAlgo::GetInputDeviceShape(cnode, 0); ShapeVector shape_i64; std::transform(shape.begin(), shape.end(), std::back_inserter(shape_i64), [](size_t x) { return SizeToLong(x); }); - // The primitive should use a clone, otherwise the attr seed will be overrode. + // The primitive should use a clone, otherwise the attr seed will be overridden. AnfNodePtrList uniform_input = {NewValueNode(prim::kPrimCudnnUniformReal->Clone())}; auto tensor = std::make_shared(kNumberTypeInt64, ShapeVector(1, SizeToLong(shape.size())), static_cast(&shape[0]), kNumberTypeInt64); diff --git a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse.cc b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse.cc index b20a6f42fb..400a0fc808 100644 --- a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse.cc +++ b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse.cc @@ -249,7 +249,8 @@ KernelRefCountPtr MemReuseUtil::GetRef(const AnfNodePtr &node, int output_idx) { if (node == nullptr) { MS_LOG(EXCEPTION) << "The node pointer is a nullptr."; } - if (node->isa()) { + // Get ref count for cnode, except monad cnode. + if (node->isa() && !HasAbstractMonad(node)) { auto ak_node = node->cast(); auto key = ak_node.get(); MemReuseChecker::GetInstance().CheckOutRef(kernel_output_refs_, ak_node, IntToSize(output_idx)); @@ -314,7 +315,8 @@ void MemReuseUtil::SetKernelDefInputs() { MS_LOG(EXCEPTION) << "kernel [" << kernel->fullname_with_scope() << "] is not init."; } auto kernel_def = iter->second; - for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel); + for (size_t i = 0; i < input_num; ++i) { auto ref_ptr = GetKernelInputRef(kernel, i); if (ref_ptr != nullptr) { // set the inputs of this kernel_def diff --git a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_checker.cc b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_checker.cc index 5d5b46f998..b734e9cf46 100644 --- a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_checker.cc +++ b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_reuse_checker.cc @@ -214,7 +214,8 @@ bool MemReuseChecker::CheckGraphOutputAssigned(const session::KernelGraph *graph // set real graph output node to be special who's refcount equal kMaxRefCount for (const auto &output : graph->outputs()) { MS_EXCEPTION_IF_NULL(output); - for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(output); ++i) { + size_t input_num = AnfAlgo::GetInputTensorNum(output); + for (size_t i = 0; i < input_num; ++i) { if (output->isa()) { auto cnode = output->cast(); auto input_node = cnode->input(i + 1); @@ -364,7 +365,8 @@ void MemReuseChecker::CheckNormalIR(const session::KernelGraph *graph) { const auto &cnodes = graph->execution_order(); for (const auto &node : cnodes) { std::vector curr_ous; - for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(node); ++i) { + size_t output_num = AnfAlgo::GetOutputTensorNum(node); + for (size_t i = 0; i < output_num; ++i) { auto it = AnfAlgo::GetOutputAddr(node, i); MS_EXCEPTION_IF_NULL(it); auto ptr = it->GetPtr(); @@ -374,7 +376,8 @@ void MemReuseChecker::CheckNormalIR(const session::KernelGraph *graph) { } (void)node_ous_.insert(std::make_pair(node.get(), curr_ous)); std::vector curr_ins; - for (size_t i = 0; i < 
AnfAlgo::GetInputTensorNum(node); ++i) { + size_t input_num = AnfAlgo::GetInputTensorNum(node); + for (size_t i = 0; i < input_num; ++i) { if (i + 1 >= node->inputs().size()) { MS_LOG(EXCEPTION) << "Input index: " << i << " is larger than input number: " << AnfAlgo::GetInputTensorNum(node); diff --git a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_swap_manager.cc b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_swap_manager.cc index 557edbaa68..154e46abe2 100644 --- a/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_swap_manager.cc +++ b/mindspore/ccsrc/backend/optimizer/mem_reuse/mem_swap_manager.cc @@ -37,7 +37,8 @@ bool MemSwapManager::Init(const mindspore::session::KernelGraph *kernel_graph, s MS_EXCEPTION_IF_NULL(kernel_mod); auto output_sizes = kernel_mod->GetOutputSizeList(); - for (size_t output_idx = 0; output_idx < AnfAlgo::GetOutputTensorNum(kernel); ++output_idx) { + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel); + for (size_t output_idx = 0; output_idx < output_num; ++output_idx) { TensorInfo tensor_info = {output_sizes[output_idx], kernel, output_idx}; ordered_tensors_.push_back(tensor_info); } diff --git a/mindspore/ccsrc/backend/optimizer/pass/communication_op_fusion.cc b/mindspore/ccsrc/backend/optimizer/pass/communication_op_fusion.cc index 883cdfec7a..88ea225322 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/communication_op_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/pass/communication_op_fusion.cc @@ -51,12 +51,14 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const CommunicationOpInfo &co rank_size = AnfAlgo::GetNodeAttr(cnode, kAttrRankSize); } MS_EXCEPTION_IF_NULL(cnode); - for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) { + size_t input_num = AnfAlgo::GetInputTensorNum(cnode); + for (size_t input_index = 0; input_index < input_num; ++input_index) { inputs_device_format.push_back(AnfAlgo::GetInputFormat(cnode, input_index)); inputs_device_type.push_back(AnfAlgo::GetInputDeviceDataType(cnode, input_index)); } for (size_t rank_index = 0; rank_index < IntToSize(rank_size); ++rank_index) { - for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(cnode); ++output_index) { + size_t output_num = AnfAlgo::GetOutputTensorNum(cnode); + for (size_t output_index = 0; output_index < output_num; ++output_index) { outputs_device_format.push_back(AnfAlgo::GetOutputFormat(cnode, output_index)); outputs_device_type.push_back(AnfAlgo::GetOutputDeviceDataType(cnode, output_index)); std::vector shape = AnfAlgo::GetOutputInferShape(cnode, output_index); @@ -170,6 +172,117 @@ bool CommunicationOpFusion::GetSplitSegments(const CommunicationOpInfo &communic return CheckSegments(segments, communication_op_node_size, segment_index); } +// Hard-code Load(%paraxxx, cnode()) to Load(%paraxxx, U) to prevent a +// cycle after the AllReduce is fused. It's a workaround. +// case 1: +// cnode_load = Load(%para2, cnode_u) +// %100 = UpdateState(cnode_u, cnode_load) +// ... +// %109 = AssignAdd(%para485, Tensor(34), %100) +// %110 = UpdateState(%100, xxx) +// will convert to: +// cnode_load = Load(%para2, U) +// ... +// %109 = AssignAdd(%para485, Tensor(34), cnode_u) +// %110 = UpdateState(cnode_u, xxx) +// +// case 2: +// cnode_load = Load(%para2, cnode_u) +// %99 = make_tuple(yyy, ..., cnode_load, ...) +// %100 = UpdateState(cnode_u, %99) +// ... +// %109 = AssignAdd(%para485, Tensor(34), %100) +// %110 = UpdateState(%100, xxx) +// will convert to: +// cnode_load = Load(%para2, U) +// %99 = make_tuple(yyy, ...)
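Case 2 (its listing continues just below) prunes the Load out of the make_tuple feeding the UpdateState. The implementation pre-sizes the new input vector to `size() - 1` and fills it with std::copy_if, which is only valid if the Load occurs exactly once among the tuple inputs. A standalone model of that pruning step (plain STL; strings stand in for AnfNodePtr):

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

int main() {
  using NodePtr = std::string;  // stands in for AnfNodePtr
  const NodePtr cnode_load = "cnode_load";
  std::vector<NodePtr> make_tuple_inputs = {"prim::MakeTuple", "yyy", "cnode_load", "zzz"};

  // Pre-sizing to size() - 1 assumes cnode_load occurs exactly once.
  assert(std::count(make_tuple_inputs.begin(), make_tuple_inputs.end(), cnode_load) == 1);
  std::vector<NodePtr> new_tuple_inputs(make_tuple_inputs.size() - 1);
  std::copy_if(make_tuple_inputs.cbegin(), make_tuple_inputs.cend(), new_tuple_inputs.begin(),
               [&cnode_load](const NodePtr &inp) { return inp != cnode_load; });
  // new_tuple_inputs == {"prim::MakeTuple", "yyy", "zzz"}
  return 0;
}
```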
+// %100 = UpdateState(cnode_u, %99) +// ... +// %109 = AssignAdd(%para485, Tensor(34), %100) +// %110 = UpdateState(%100, xxx) +// +// case 3: +// cnode_load = Load(%para2, cnode_u) +// %99 = make_tuple(cnode_load) +// %100 = UpdateState(cnode_u, %99) +// ... +// %109 = AssignAdd(%para485, Tensor(34), %100) +// %110 = UpdateState(%100, xxx) +// will convert to: +// cnode_load = Load(%para2, U) +// ... +// %109 = AssignAdd(%para485, Tensor(34), cnode_u) +// %110 = UpdateState(cnode_u, xxx) +static void AdjustAllReduceInputWithLoad(const CNodePtr &cnode) { + auto cnode_load = BroadFirstSearchFirstOf({cnode}, [](const CNodePtr &search_cnode) { + if (!IsPrimitiveCNode(search_cnode, prim::kPrimLoad)) { + return false; + } + if (search_cnode->inputs().size() != 3) { + MS_LOG(EXCEPTION) << "Load CNode should have 3 inputs, but: " << search_cnode->DebugString(); + } + return search_cnode->input(2)->isa(); + }); + if (cnode_load != nullptr) { + const auto &const_u_monad = NewValueNode(kUMonad); + const auto &cnode_u = cnode_load->input(2); + MS_LOG(DEBUG) << "Replace the CNode U of Load with constant U for cnode: " << cnode_load->DebugString(); + MS_EXCEPTION_IF_NULL(cnode->func_graph()); + MS_EXCEPTION_IF_NULL(cnode->func_graph()->manager()); + auto manager = cnode->func_graph()->manager(); + manager->SetEdge(cnode_load, 2, const_u_monad); + // Update the u_monad input of the UpdateState that takes the same CNode U as the Load, to constant U. + CNodePtr cnode_update_state = nullptr; + CNodePtr cnode_make_tuple = nullptr; + const auto &cnode_load_users = manager->node_users()[cnode_load]; + for (auto &load_user : cnode_load_users) { + if (IsPrimitiveCNode(load_user.first, prim::kPrimMakeTuple)) { + const auto &cnode_make_tuple_users = manager->node_users()[load_user.first]; + for (auto &make_tuple_user : cnode_make_tuple_users) { + if (IsPrimitiveCNode(make_tuple_user.first, prim::kPrimUpdateState)) { + const auto &cnode_user = make_tuple_user.first->cast(); + if (cnode_user->input(1) == cnode_u) { + cnode_update_state = cnode_user; + cnode_make_tuple = load_user.first->cast(); + break; + } + } + } + if (cnode_update_state != nullptr) { + break; + } + } + if (IsPrimitiveCNode(load_user.first, prim::kPrimUpdateState)) { + const auto &cnode_user = load_user.first->cast(); + if (cnode_user->input(1) == cnode_u) { + cnode_update_state = cnode_user; + break; + } + } + } + if (cnode_update_state != nullptr) { + if (cnode_make_tuple == nullptr || cnode_make_tuple->inputs().size() == 2) { + // case 1 and case 3: replace cnode_update_state with cnode_u; + MS_LOG(DEBUG) << "Replace UpdateState with CNode U: " << cnode_update_state->DebugString() + << " ::TO:: " << cnode_u->DebugString(); + manager->Replace(cnode_update_state, cnode_u); + } else if (cnode_make_tuple->inputs().size() > 2) { + // case 2: remove cnode_load from cnode_make_tuple; + MS_LOG(DEBUG) << "Drop " << cnode_load->DebugString() << " from " << cnode_make_tuple->DebugString(); + const auto &make_tuple_inputs = cnode_make_tuple->inputs(); + AnfNodePtrList new_tuple_inputs(make_tuple_inputs.size() - 1); + std::copy_if(make_tuple_inputs.cbegin(), make_tuple_inputs.cend(), new_tuple_inputs.begin(), + [cnode_load](const auto &inp) { return inp != cnode_load; }); + auto new_cnode_make_tuple = cnode_make_tuple->func_graph()->NewCNode(new_tuple_inputs); + manager->Replace(cnode_make_tuple, new_cnode_make_tuple); + } else { + MS_LOG(EXCEPTION) << "Cannot replace UpdateState with CNode U: " << cnode_update_state->DebugString() + << " as make_tuple CNode cannot match " <<
cnode_make_tuple->DebugString(); + } + } + } +} + AnfNodePtr CommunicationOpFusion::CreateFusedCommunicationOp(const FuncGraphPtr &func_graph, const CommunicationOpInfo &communication_op_info, size_t start_index, size_t end_index) const { @@ -184,6 +297,9 @@ AnfNodePtr CommunicationOpFusion::CreateFusedCommunicationOp(const FuncGraphPtr for (size_t idx = start_index; idx <= end_index; ++idx) { auto cnode = communication_op_info.communication_op_nodes[idx]; MS_EXCEPTION_IF_NULL(cnode); + if (idx != start_index) { + AdjustAllReduceInputWithLoad(cnode); + } fusion_inputs.insert(fusion_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end()); } CheckInputs(fusion_inputs); diff --git a/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_tensor_input.cc b/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_tensor_input.cc index 6d4861e8cd..d7ffd33243 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_tensor_input.cc +++ b/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_tensor_input.cc @@ -107,9 +107,7 @@ AnfNodePtr ProcessGraphKernelOp(const AnfNodePtr &node) { auto mng = sub_graph->manager(); MS_EXCEPTION_IF_NULL(mng); std::vector todo; - std::vector> graph_rets; kernel::GetValidKernelNodes(sub_graph, &todo); - kernel::GetGraphRealOutput(sub_graph, &graph_rets); for (auto &t : todo) { auto t_new_node = ConstInputToTensorInput(sub_graph, t->cast()); diff --git a/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_input_to_dynamic_input.cc b/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_input_to_dynamic_input.cc index 6aeb671177..5465c1aa5a 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_input_to_dynamic_input.cc +++ b/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_input_to_dynamic_input.cc @@ -37,7 +37,8 @@ void ConvertMakeTupleInputToPlantInputs(const FuncGraphPtr &graph, const CNodePt std::vector plant_inputs; std::vector dyn_input_sizes; plant_inputs.push_back(AnfAlgo::GetCNodePrimitiveNode(cnode_ptr)); - for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(cnode_ptr); ++i) { + size_t input_num = AnfAlgo::GetInputTensorNum(cnode_ptr); + for (size_t i = 0; i < input_num; ++i) { auto input_node = AnfAlgo::GetInputNode(cnode_ptr, i); MS_EXCEPTION_IF_NULL(input_node); if (input_node->isa() && AnfAlgo::CheckPrimitiveType(input_node, prim::kPrimMakeTuple)) { @@ -45,7 +46,8 @@ void ConvertMakeTupleInputToPlantInputs(const FuncGraphPtr &graph, const CNodePt dyn_input_sizes.push_back(input_size); auto make_tuple = input_node->cast(); MS_EXCEPTION_IF_NULL(make_tuple); - for (size_t j = 0; j < AnfAlgo::GetInputTensorNum(make_tuple); ++j) { + size_t tuple_input_num = AnfAlgo::GetInputTensorNum(make_tuple); + for (size_t j = 0; j < tuple_input_num; ++j) { auto dyn_input_node = AnfAlgo::GetInputNode(make_tuple, j); MS_EXCEPTION_IF_NULL(dyn_input_node); if (IsValueNode(dyn_input_node)) { diff --git a/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_output_to_maketuple.cc b/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_output_to_maketuple.cc index 207407436b..53c34b7dc8 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_output_to_maketuple.cc +++ b/mindspore/ccsrc/backend/optimizer/pass/convert_tuple_output_to_maketuple.cc @@ -65,7 +65,7 @@ const AnfNodePtr ConvertTupleOutputToMaketuple::Process(const FuncGraphPtr &func return nullptr; } } - if (IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) { + if (IsPrimitiveCNode(cnode, prim::kPrimUpdateState)) { return nullptr; } bool 
cnode_input_changed = false; diff --git a/mindspore/ccsrc/backend/optimizer/pass/eliminate_redundant_op.cc b/mindspore/ccsrc/backend/optimizer/pass/eliminate_redundant_op.cc index 75183f10ab..abc7d24486 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/eliminate_redundant_op.cc +++ b/mindspore/ccsrc/backend/optimizer/pass/eliminate_redundant_op.cc @@ -58,6 +58,9 @@ CNodePtr GetRealPrevCNode(const AnfNodePtr &node, size_t index, std::vectorpush_back(make_pair(cnode, IntToSize(1))); return GetRealPrevCNode(cnode->input(1), 0, pass_vector); + } else if (IsPrimitive(input0, prim::kPrimUpdateState)) { + pass_vector->push_back(make_pair(cnode, IntToSize(kUpdateStateRealInput))); + return GetRealPrevCNode(cnode->input(kUpdateStateRealInput), 0, pass_vector); } else { return nullptr; } diff --git a/mindspore/ccsrc/backend/optimizer/pass/getitem_tuple.cc b/mindspore/ccsrc/backend/optimizer/pass/getitem_tuple.cc index 2b3e2b7e27..540310ab9a 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/getitem_tuple.cc +++ b/mindspore/ccsrc/backend/optimizer/pass/getitem_tuple.cc @@ -45,9 +45,7 @@ const AnfNodePtr GetitemTuple::Process(const FuncGraphPtr &, const AnfNodePtr &n MS_EXCEPTION_IF_NULL(node); CNodePtr tuple_getitem = node->cast(); MS_EXCEPTION_IF_NULL(tuple_getitem); - if (tuple_getitem->inputs().size() < kTupleGetitemInputNum) { - MS_LOG(EXCEPTION) << "tuple getitem's input num is wrong"; - } + CheckCNodeInputSize(tuple_getitem, kTupleGetItemInputTensorNum); AnfNodePtr make_tuple_anf = tuple_getitem->input(kRealInputNodeIndexInTupleGetItem); MS_EXCEPTION_IF_NULL(make_tuple_anf); AnfNodePtr index_node = tuple_getitem->input(kInputNodeOutputIndexInTupleGetItem); diff --git a/mindspore/ccsrc/backend/optimizer/pass/optimize_dependence.cc b/mindspore/ccsrc/backend/optimizer/pass/optimize_dependence.cc index 730cd5b83e..8fbed9c8a0 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/optimize_dependence.cc +++ b/mindspore/ccsrc/backend/optimizer/pass/optimize_dependence.cc @@ -43,7 +43,7 @@ AnfNodePtr GetReplaceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node if (!IsNotRealUsedByOthers(func_graph, cnode)) { return nullptr; } - CheckCNodeInputSize(cnode, kSingleInputIndex + 1); + CheckCNodeInputSize(cnode, kSingleInputIndex); return cnode->input(kSingleInputIndex); } @@ -55,7 +55,8 @@ AnfNodePtr ReplaceMakeTuple(const FuncGraphPtr &func_graph, const CNodePtr &cnod } std::vector new_make_tuple_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)}; bool need_update = false; - for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(cnode); ++index) { + size_t input_num = AnfAlgo::GetInputTensorNum(cnode); + for (size_t index = 0; index < input_num; ++index) { auto input = AnfAlgo::GetInputNode(cnode, index); AnfNodePtr replace_input = GetReplaceNode(func_graph, input); // If replace input is not null, it will be the input of the TransData or Cast. @@ -91,6 +92,29 @@ const BaseRef OptimizeDependence::DefinePattern() const { return VectorRef({X, Xs}); } +std::pair SearchTransDataAndCast(const AnfNodePtr &node, bool is_first_node) { + if (node == nullptr || !node->isa()) { + return std::pair(nullptr, 0); + } + // get real input of depend and update state. + size_t replace_input_index = 0; + if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimDepend)) { + replace_input_index = is_first_node ? kDependAttachNodeIndex : kRealInputIndexInDepend; + } else if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimUpdateState)) { + replace_input_index = is_first_node ? 
kUpdateStateStateInput : kUpdateStateRealInput; + } else { + return std::pair(nullptr, 0); + } + // Check whether the real input is a Cast or TransData. + auto real_input = node->cast()->input(replace_input_index); + if (AnfAlgo::CheckPrimitiveType(real_input, prim::kPrimCast) || + AnfAlgo::CheckPrimitiveType(real_input, prim::KPrimTransData) || + AnfAlgo::CheckPrimitiveType(real_input, prim::kPrimMakeTuple)) { + return std::pair(node, replace_input_index); + } + return SearchTransDataAndCast(real_input, false); +} + const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const { MS_EXCEPTION_IF_NULL(func_graph); @@ -98,42 +122,36 @@ const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, con if (!node->isa()) { return nullptr; } - auto node_name = AnfAlgo::GetCNodeName(node); - if (node_name != prim::kPrimControlDepend->name() && node_name != prim::kPrimDepend->name()) { + // Get the cnode and the index of the input to replace. + auto cnode_with_input_index = SearchTransDataAndCast(node, true); + if (cnode_with_input_index.first == nullptr) { return nullptr; } - size_t index = 0; - auto depend_cnode = node->cast(); + size_t replace_index = cnode_with_input_index.second; + auto depend_cnode = cnode_with_input_index.first->cast(); MS_EXCEPTION_IF_NULL(depend_cnode); - std::vector new_depend_inputs = {depend_cnode->input(kAnfPrimitiveIndex)}; - if (node_name == prim::kPrimDepend->name()) { - index = 1; - new_depend_inputs.push_back(depend_cnode->input(kRealInputIndexInDepend)); - } - if (AnfAlgo::GetInputTensorNum(depend_cnode) < 2) { - MS_LOG(EXCEPTION) << "The depend node input size is at less size 2,but got " - << AnfAlgo::GetInputTensorNum(depend_cnode) << depend_cnode->DebugString(); - } - auto input_num = AnfAlgo::GetInputTensorNum(depend_cnode); - while (index < input_num) { - auto replace_node = GetConvertNode(func_graph, node, index); - MS_EXCEPTION_IF_NULL(replace_node); - new_depend_inputs.push_back(replace_node); - ++index; + // Get the new node which will act as the new input of the Depend or UpdateState. + std::vector new_depend_inputs = depend_cnode->inputs(); + auto replace_node = GetConvertNode(func_graph, depend_cnode, replace_index); + if (replace_node == nullptr) { + return nullptr; + } + new_depend_inputs[replace_index] = replace_node; + // Because the depend's input has been changed, a new Depend (UpdateState) node will be created to replace the old one.
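At the IR level, the rewritten pass does the following for a state chain fed by a conversion node (hypothetical example; the replacement itself happens in the code that follows):

```cpp
// Given a single-use conversion feeding a state chain:
//   %c  = Cast(%x)                 // used only by the UpdateState below
//   %s  = UpdateState(%u, %c)
//
// SearchTransDataAndCast finds (%s, real-input index), GetConvertNode
// resolves %c to its own input, and a new node is built:
//   %s' = UpdateState(%u, %x)
// %s is then replaced by %s' through the manager, and Process returns
// nullptr so the pattern engine does not substitute the matched node again.
```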
auto kernel_graph = func_graph->cast>(); CNodePtr new_depend = nullptr; if (kernel_graph == nullptr) { new_depend = func_graph->NewCNode(new_depend_inputs); MS_EXCEPTION_IF_NULL(new_depend); - new_depend->set_abstract(node->abstract()); - new_depend->set_scope(node->scope()); + new_depend->set_abstract(depend_cnode->abstract()); + new_depend->set_scope(depend_cnode->scope()); } else { new_depend = kernel_graph->NewCNode(depend_cnode); MS_EXCEPTION_IF_NULL(new_depend); new_depend->set_inputs(new_depend_inputs); } - return new_depend; + func_graph->manager()->Replace(depend_cnode, new_depend); + return nullptr; } const AnfNodePtr OptimizeDependence::GetConvertNode(const FuncGraphPtr &graph, const AnfNodePtr &node, @@ -141,10 +159,10 @@ const AnfNodePtr OptimizeDependence::GetConvertNode(const FuncGraphPtr &graph, c MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(node); auto depend_cnode = node->cast(); - auto replacing_node = AnfAlgo::GetInputNode(depend_cnode, index); + auto replacing_node = depend_cnode->input(index); MS_EXCEPTION_IF_NULL(replacing_node); if (!replacing_node->isa()) { - return replacing_node; + return nullptr; } auto replacing_cnode = replacing_node->cast(); MS_EXCEPTION_IF_NULL(replacing_cnode); @@ -154,10 +172,6 @@ const AnfNodePtr OptimizeDependence::GetConvertNode(const FuncGraphPtr &graph, c return make_tuple_replace_node; } AnfNodePtr replace_node = GetReplaceNode(graph, replacing_cnode); - if (replace_node == nullptr) { - MS_LOG(DEBUG) << "Can not find the TransData or Cast with single output node. Depend node: " << node->DebugString(); - return replacing_node; - } return replace_node; } } // namespace opt diff --git a/mindspore/ccsrc/backend/optimizer/pass/replace_node_by_proxy.cc b/mindspore/ccsrc/backend/optimizer/pass/replace_node_by_proxy.cc index 8c6dab79be..d4fe31013f 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/replace_node_by_proxy.cc +++ b/mindspore/ccsrc/backend/optimizer/pass/replace_node_by_proxy.cc @@ -30,11 +30,13 @@ kernel::KernelBuildInfoPtr ReplaceNodeByProxy::GenerateKernelBuildInfo(const CNo std::vector outputs_device_type; std::vector> outputs_shape; kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; - for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(cnode); ++input_index) { + size_t input_num = AnfAlgo::GetInputTensorNum(cnode); + for (size_t input_index = 0; input_index < input_num; ++input_index) { inputs_device_format.push_back(AnfAlgo::GetInputFormat(cnode, input_index)); inputs_device_type.push_back(AnfAlgo::GetInputDeviceDataType(cnode, input_index)); } - for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(cnode); ++output_index) { + size_t output_num = AnfAlgo::GetOutputTensorNum(cnode); + for (size_t output_index = 0; output_index < output_num; ++output_index) { outputs_device_format.push_back(AnfAlgo::GetOutputFormat(cnode, output_index)); outputs_device_type.push_back(AnfAlgo::GetOutputDeviceDataType(cnode, output_index)); outputs_shape.push_back(AnfAlgo::GetOutputInferShape(cnode, output_index)); diff --git a/mindspore/ccsrc/backend/session/CMakeLists.txt b/mindspore/ccsrc/backend/session/CMakeLists.txt index e67fcbc7b8..7fabffee04 100644 --- a/mindspore/ccsrc/backend/session/CMakeLists.txt +++ b/mindspore/ccsrc/backend/session/CMakeLists.txt @@ -26,6 +26,7 @@ if(ENABLE_D) file(GLOB_RECURSE _D_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "ascend_session.cc" "ascend_control_parser.cc" + "ascend_auto_monad.cc" "ascend_inference_session.cc" ) list(APPEND _SESSION_SRC_LIST 
${_D_SRC_LIST}) diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc index b24ef0557a..bdd99e0c7d 100644 --- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc +++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc @@ -18,7 +18,7 @@ #include #include #include -#include +#include #include "ir/anf.h" #include "ir/func_graph.h" #include "base/core_ops.h" @@ -48,6 +48,10 @@ namespace { constexpr size_t kNopNodeInputSize = 2; constexpr size_t kNopNodeRealInputIndex = 1; +using PrimitiveSet = std::unordered_set; + +PrimitiveSet follow_first_input_prims = {prim::kPrimDepend, prim::kPrimLoad}; + bool IsShapeDynamic(const abstract::ShapePtr &shape) { MS_EXCEPTION_IF_NULL(shape); return std::any_of(shape->shape().begin(), shape->shape().end(), [](int64_t s) { return s < 0; }); @@ -57,6 +61,33 @@ bool IsShapeDynamic(const std::vector &shape) { return std::any_of(shape.begin(), shape.end(), [](int64_t s) { return s < 0; }); } +bool IsOneOfPrimitive(const AnfNodePtr &node, const PrimitiveSet &prim_set) { + PrimitivePtr prim = GetValueNode(node); + return (prim && prim_set.find(prim) != prim_set.end()); +} + +bool IsOneOfPrimitiveCNode(const AnfNodePtr &node, const PrimitiveSet &prim_set) { + MS_EXCEPTION_IF_NULL(node); + auto cnode = node->cast(); + if (cnode == nullptr || cnode->size() == 0) { + return false; + } + return IsOneOfPrimitive(cnode->inputs().at(kAnfPrimitiveIndex), prim_set); +} + +bool IsRealKernelCNode(const CNodePtr &cnode) { + static const PrimitiveSet virtual_prims = { + prim::kPrimImageSummary, prim::kPrimScalarSummary, prim::kPrimTensorSummary, prim::kPrimHistogramSummary, + prim::kPrimMakeTuple, prim::kPrimStateSetItem, prim::kPrimTupleGetItem, prim::kPrimReturn, + prim::kPrimPartial, prim::kPrimDepend, prim::kPrimUpdateState, prim::kPrimLoad}; + if (cnode->inputs().empty()) { + MS_LOG(EXCEPTION) << "Illegal null input of cnode(%s)" << cnode->DebugString(); + } + const auto &input = cnode->inputs().at(0); + bool is_virtual_node = IsOneOfPrimitive(input, virtual_prims); + return !is_virtual_node; +} + std::vector TransShapeToSizet(const abstract::ShapePtr &shape) { MS_EXCEPTION_IF_NULL(shape); std::vector shape_size_t; @@ -121,7 +152,9 @@ KernelWithIndex AnfRuntimeAlgorithm::VisitKernel(const AnfNodePtr &anf_node, siz MS_EXCEPTION_IF_NULL(value_node); auto item_idx = GetValue(value_node->value()); return VisitKernel(cnode->input(kRealInputNodeIndexInTupleGetItem), LongToSize(item_idx)); - } else if (IsPrimitive(input0, prim::kPrimDepend) || IsPrimitive(input0, prim::kPrimControlDepend)) { + } else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimUpdateState)) { + return VisitKernel(cnode->input(kUpdateStateRealInput), 0); + } else if (IsOneOfPrimitive(input0, follow_first_input_prims)) { return VisitKernel(cnode->input(kRealInputIndexInDepend), 0); } else { return std::make_pair(anf_node, index); @@ -162,7 +195,10 @@ KernelWithIndex AnfRuntimeAlgorithm::VisitKernelWithReturnType(const AnfNodePtr } return item_with_index_tmp; } - if (CheckPrimitiveType(cnode, prim::kPrimDepend) || CheckPrimitiveType(cnode, prim::kPrimControlDepend)) { + if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimUpdateState)) { + return VisitKernelWithReturnType(cnode->input(kUpdateStateStateInput), index, visit_nop_node, return_types); + } + if (IsOneOfPrimitiveCNode(cnode, follow_first_input_prims)) { return VisitKernelWithReturnType(cnode->input(kRealInputIndexInDepend), index, visit_nop_node, return_types); } if 
(opt::IsNopNode(cnode) && visit_nop_node) { @@ -347,21 +383,48 @@ bool AnfRuntimeAlgorithm::HasNodeAttr(const std::string &key, const CNodePtr &no return fg->has_attr(key); } +size_t AnfRuntimeAlgorithm::GetInputNum(const CNodePtr &cnode) { + MS_EXCEPTION_IF_NULL(cnode); + size_t input_num = cnode->size(); + if (input_num == 0) { + MS_LOG(EXCEPTION) << "Cnode inputs size can't be zero"; + } + return input_num - 1; +} + size_t AnfRuntimeAlgorithm::GetInputTensorNum(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); - if (!node->isa()) { + auto cnode = node->cast(); + if (cnode == nullptr) { MS_LOG(EXCEPTION) << "Only cnode has real input, but this anf is " << node->DebugString() << " trace: " << trace::DumpSourceLines(node); } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); + ssize_t input_tensor_num = cnode->input_tensor_num(); + if (input_tensor_num >= 0) { + return static_cast(input_tensor_num); + } size_t input_num = cnode->inputs().size(); if (input_num == 0) { MS_LOG(EXCEPTION) << "Cnode inputs size can't be zero" << " trace: " << trace::DumpSourceLines(node); } - // exclude inputs[0],which is value_node storing attr,inputs left are real input - return input_num - 1; + // Exclude inputs[0]. + --input_num; + + // Exclude monad inputs for real cnodes. + if (input_num > 0 && IsRealKernelCNode(cnode)) { + auto &inputs = cnode->inputs(); + // Search monad inputs, backward. + for (auto iter = inputs.rbegin(); iter != inputs.rend(); ++iter) { + if (!HasAbstractMonad(*iter)) { + // Stop count if we encounter a non-monad input. + break; + } + --input_num; + } + } + cnode->set_input_tensor_num(static_cast(input_num)); + return input_num; } size_t AnfRuntimeAlgorithm::GetOutputTensorNum(const AnfNodePtr &node) { @@ -374,13 +437,11 @@ size_t AnfRuntimeAlgorithm::GetOutputTensorNum(const AnfNodePtr &node) { auto tuple_type = type->cast(); MS_EXCEPTION_IF_NULL(tuple_type); return tuple_type->size(); - } else if (type->isa() || type->isa()) { - return 1; - } else if (type->isa()) { + } + if (type->isa()) { return 0; - } else { - return 1; } + return 1; } std::vector AnfRuntimeAlgorithm::GetAllOutputFormats(const AnfNodePtr &node) { @@ -986,14 +1047,7 @@ bool AnfRuntimeAlgorithm::IsRealKernel(const AnfNodePtr &node) { MS_LOG(EXCEPTION) << "Illegal null input of cnode(%s)" << node->DebugString() << " trace: " << trace::DumpSourceLines(node); } - auto input = cnode->inputs()[0]; - bool is_virtual_node = IsPrimitive(input, prim::kPrimImageSummary) || IsPrimitive(input, prim::kPrimScalarSummary) || - IsPrimitive(input, prim::kPrimTensorSummary) || - IsPrimitive(input, prim::kPrimHistogramSummary) || IsPrimitive(input, prim::kPrimMakeTuple) || - IsPrimitive(input, prim::kPrimStateSetItem) || IsPrimitive(input, prim::kPrimDepend) || - IsPrimitive(input, prim::kPrimTupleGetItem) || IsPrimitive(input, prim::kPrimControlDepend) || - IsPrimitive(input, prim::kPrimReturn) || IsPrimitive(input, prim::kPrimPartial); - return !is_virtual_node; + return IsRealKernelCNode(cnode); } bool AnfRuntimeAlgorithm::IsRealCNodeKernel(const AnfNodePtr &node) { @@ -1120,7 +1174,7 @@ AnfNodePtr AnfRuntimeAlgorithm::GetInputNode(const CNodePtr &node, size_t index) bool AnfRuntimeAlgorithm::IsFeatureMapOutput(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); - if (node->isa()) { + if (node->isa() || IsPrimitiveCNode(node, prim::kPrimLoad)) { return false; } auto kernel_info = static_cast(node->kernel_info()); @@ -1297,7 +1351,7 @@ bool AnfRuntimeAlgorithm::IsScalarOutput(const CNodePtr &cnode, size_t index) { 
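The reworked GetInputTensorNum above excludes trailing monad inputs for real kernels and caches the result on the cnode. Since monad inputs are appended after all tensor inputs, a backward scan can stop at the first non-monad input. A standalone model of that count (plain STL; the bool flags stand in for HasAbstractMonad on each input):

```cpp
#include <cstddef>
#include <vector>

// inputs[0] models the primitive, then tensor inputs, then monad inputs.
// Mirrors the backward monad scan above: count inputs, drop the primitive,
// then drop trailing monads.
size_t CountTensorInputs(const std::vector<bool> &input_is_monad) {
  size_t input_num = input_is_monad.size();
  if (input_num == 0) {
    return 0;
  }
  --input_num;  // exclude inputs[0], the primitive value node
  for (auto it = input_is_monad.rbegin(); it != input_is_monad.rend(); ++it) {
    if (!*it) {
      break;  // stop counting at the first non-monad input
    }
    --input_num;
  }
  return input_num;
}

// Example: Assign(para, value, U) models as {false /*prim*/, false, false, true}
// and yields 2 tensor inputs, while the raw input count (GetInputNum) is 3.
```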
return shape.size() == kShape1dDims && shape[0] == 1; } -void AnfRuntimeAlgorithm::ReorderExecList(NotNull *> node_list) { +void AnfRuntimeAlgorithm::ReorderOptimizerExecList(NotNull *> node_list) { std::vector all_opt_list; std::vector non_opt_list; std::vector trans_list; @@ -1372,6 +1426,23 @@ void AnfRuntimeAlgorithm::ReorderExecList(NotNull *> node_ std::copy(cast_list.begin(), cast_list.end(), std::back_inserter(*node_list)); } +void AnfRuntimeAlgorithm::ReorderPosteriorExecList(NotNull *> node_list) { + std::vector ordinary_node_list; + std::vector posterior_node_list; + + for (const auto &node : *node_list) { + MS_EXCEPTION_IF_NULL(node); + if (kPosteriorOperatorSet.find(AnfAlgo::GetCNodeName(node)) != kPosteriorOperatorSet.end()) { + posterior_node_list.emplace_back(node); + } else { + ordinary_node_list.emplace_back(node); + } + } + node_list->clear(); + std::copy(ordinary_node_list.begin(), ordinary_node_list.end(), std::back_inserter(*node_list)); + std::copy(posterior_node_list.begin(), posterior_node_list.end(), std::back_inserter(*node_list)); +} + TypeId AnfRuntimeAlgorithm::GetCNodeOutputPrecision(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); auto prim = AnfAlgo::GetCNodePrimitive(node); diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h index 0e3a876527..3ea2c92f3c 100644 --- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h +++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h @@ -97,7 +97,9 @@ class AnfRuntimeAlgorithm { static bool HasNodeAttr(const std::string &key, const CNodePtr &node); // delete attr of anf node static void EraseNodeAttr(const std::string &key, AnfNodePtr node); - // get the num of input real_kernel(which can be build and run in device) + // get the num of inputs include monads for a cnode + static size_t GetInputNum(const CNodePtr &cnode); + // get the num of inputs exclude monads for real_kernel (which can be build and run in device) static size_t GetInputTensorNum(const AnfNodePtr &node); // get the num of output real_kernel(which can be build and run in device) static size_t GetOutputTensorNum(const AnfNodePtr &node); @@ -221,7 +223,8 @@ class AnfRuntimeAlgorithm { static bool IsSwitchCall(const CNodePtr &call_node); static bool IsScalarInput(const CNodePtr &cnode, size_t index); static bool IsScalarOutput(const CNodePtr &cnode, size_t index); - static void ReorderExecList(NotNull *> node_list); + static void ReorderOptimizerExecList(NotNull *> node_list); + static void ReorderPosteriorExecList(NotNull *> node_list); // get fix output precision of cnode. static TypeId GetCNodeOutputPrecision(const AnfNodePtr &node); // get fix output precision from prev node, input_idx is the input index of current node related to prev node. diff --git a/mindspore/ccsrc/backend/session/ascend_auto_monad.cc b/mindspore/ccsrc/backend/session/ascend_auto_monad.cc new file mode 100644 index 0000000000..ef7e7267cc --- /dev/null +++ b/mindspore/ccsrc/backend/session/ascend_auto_monad.cc @@ -0,0 +1,981 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/session/ascend_auto_monad.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "utils/ms_context.h" +#include "base/core_ops.h" +#include "debug/anf_ir_dump.h" +#include "pipeline/jit/base.h" +#include "backend/session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace session { +namespace { + +// Pair of graph and its actual arguments. +using GraphArgPair = std::pair>; + +// We start label id from 1, and use 0 to indicate label not set. +constexpr uint32_t kNoLabel = 0; + +// Primitive attribute for argument link assign. +const char LINK[] = "link"; + +bool IsSaveGraph() { + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + return context_ptr->get_param(MS_CTX_SAVE_GRAPHS_FLAG); +} + +void DumpAllGraphs(NotNull kg, std::set *memo) { + if (memo->find(kg) != memo->end()) { + return; + } + memo->insert(kg); + std::string file_name = "ascend_auto_monad_" + std::to_string(kg->graph_id()) + ".ir"; + DumpIR(file_name, kg.get()); + for (auto &child : kg->child_graph_order()) { + auto cg = child.lock(); + if (cg) { + DumpAllGraphs(NOT_NULL(cg), memo); + } + } +} + +void DumpGraphForDebug(NotNull kg) { + if (IsSaveGraph()) { + std::set memo; + DumpAllGraphs(kg, &memo); + } +} + +void DumpExecuteOrder(NotNull kg) { + if (!IsSaveGraph()) { + return; + } + std::string filename = "ascend_execute_order_" + std::to_string(kg->graph_id()) + ".dat"; + auto filepath = pipeline::GetSaveGraphsPathName(filename); + char real_path[PATH_MAX] = {0}; +#if defined(_WIN32) || defined(_WIN64) + if (_fullpath(filepath, filename.c_str(), PATH_MAX) == nullptr) { + MS_LOG(DEBUG) << "Dir " << filename << " does not exist."; + } +#else + if (realpath(filepath.c_str(), real_path) == nullptr) { + MS_LOG(DEBUG) << "Dir " << filepath << " does not exist."; + } +#endif + + std::ofstream fout(real_path); + if (!fout.is_open()) { + MS_LOG(ERROR) << "Open file '" << real_path << "' failed!"; + return; + } + + fout << "Execute order:\n"; + int index = 0; + for (auto &cnode : kg->execution_order()) { + MS_EXCEPTION_IF_NULL(cnode); + if (IsPrimitiveCNode(cnode, prim::kPrimLabelSet)) { + fout << "L" << AnfAlgo::GetNodeAttr(cnode, kAttrLabelIndex) << ":\n"; + } + fout << " [" << index << "], " << cnode->DebugString(); + if (AnfAlgo::HasNodeAttr(kAttrLabelIndex, cnode)) { + fout << " : L" << AnfAlgo::GetNodeAttr(cnode, kAttrLabelIndex); + } + if (AnfAlgo::HasNodeAttr(kAttrLabelSwitchList, cnode)) { + auto labels = AnfAlgo::GetNodeAttr>(cnode, kAttrLabelSwitchList); + fout << " : "; + for (size_t i = 0; i < labels.size(); ++i) { + fout << ((i > 0) ? ", L" : "L") << labels[i]; + } + } + fout << '\n'; + index++; + } + fout.close(); +} + +// +// ParameterPool caches parameters by their abstract, so that we can reuse +// a parameter with the same abstract to store return values. +// +class ParameterPool { + public: + explicit ParameterPool(const KernelGraphPtr &top_graph) : top_graph_(top_graph) {} + ~ParameterPool() = default; + + // Create or get a parameter from the pool with the given abstract.
+ AnfNodePtr GetParameter(const abstract::AbstractBasePtr &abs) { + // Find parameter in pool by the given abstract. + auto iter = std::find_if(paras_.begin(), paras_.end(), [&abs](auto &para) { + auto para_abs = para->abstract(); + // Reuse output parameter with compatible abstract. + return IsCompatible(abs, para_abs); + }); + // Return the parameter if found. + if (iter != paras_.end()) { + return *iter; + } + // If parameter not found with the given abstract, create a new one. + auto para = top_graph_->NewParameter(abs); + auto out_para = top_graph_->TransTupleToMakeTuple(para); + // This is required, so that device memory can be allocated for it. + top_graph_->AddChildGraphResult(out_para); + // Save new para to pool. + paras_.push_back(out_para); + return out_para; + } + + protected: + // Check if one abstract is compatible with another abstract. + static bool IsCompatible(const abstract::AbstractBasePtr &a1, const abstract::AbstractBasePtr &a2) { + if (a1 == nullptr || a2 == nullptr) { + return false; + } + if (a1->isa() && a2->isa()) { + // This makes AbstractRef compatible with AbstractTensor. + auto &t1 = static_cast(*a1); + auto &t2 = static_cast(*a2); + return t1 == t2; + } + return *a1 == *a2; + } + + private: + // The top graph. + const KernelGraphPtr &top_graph_; + + // Cached parameters. + std::vector paras_; +}; + +using ParameterPoolPtr = std::shared_ptr; + +class BaseContext { + public: + void MarkVisited(const KernelGraphPtr &kg) { visited_graphs_.insert(kg); } + + bool IsVisited(const KernelGraphPtr &kg) const { return visited_graphs_.find(kg) != visited_graphs_.end(); } + + const std::set &visited_graphs() const { return visited_graphs_; } + + private: + std::set visited_graphs_; +}; + +// +// AscendAutoMonadContext holds some shared state during auto-monad conversion. +// +class AscendAutoMonadContext : public BaseContext { + public: + explicit AscendAutoMonadContext(const KernelGraphPtr &kg) : top_graph_(kg) {} + ~AscendAutoMonadContext() = default; + + // Label ids start from 1 and are increased by 1 for each new id. + uint32_t NewLabel() { return label_id_++; } + + // Current label id, also the number of label ids used so far. + uint32_t CurrentLabel() const { return label_id_; } + + // Create a new parameter pool. + ParameterPoolPtr NewParameterPool() { return std::make_shared(top_graph_); } + + private: + // The top graph. + const KernelGraphPtr &top_graph_; + + // Current label id. + uint32_t label_id_ = 1; +}; + +// +// AscendAutoMonadConverter converts control flow to monad form +// for a kernel graph and its child graphs recursively. +// +class AscendAutoMonadConverter { + public: + AscendAutoMonadConverter(AscendAutoMonadContext *context, const KernelGraphPtr &kg) + : context_(*context), kernel_graph_(kg) {} + + ~AscendAutoMonadConverter() = default; + + void Run() { + // Skip if graph already visited. + if (context_.IsVisited(kernel_graph_)) { + return; + } + context_.MarkVisited(kernel_graph_); + + // Update directly called sub-graphs. + kernel_graph_->UpdateChildGraphOrder(); + + Prepare(); + + // Setup entry label if needed. + auto entry_label = GetGraphLabel(kernel_graph_); + if (entry_label != kNoLabel) { + SetupEntryLabel(entry_label); + } + + // Handle call and switch nodes. + HandleCallSwitch(); + + // Let output depend on monad. + if (monad_) { + MakeMonadDepend(); + } + } + + private: + // + // Prepare information for control flow processing.
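GetParameter above is a linear find-or-create over the cached parameters, keyed by abstract compatibility. A standalone model of that reuse pattern (plain STL; a string key stands in for the abstract, and IsCompatible for the AbstractRef/AbstractTensor check):

```cpp
#include <memory>
#include <string>
#include <vector>

struct Param {
  std::string key;  // stands in for the parameter's abstract
};
using ParamPtr = std::shared_ptr<Param>;

class Pool {
 public:
  // Return an existing parameter with a compatible key, or create one.
  ParamPtr GetParameter(const std::string &key) {
    for (const auto &p : params_) {
      if (IsCompatible(key, p->key)) {
        return p;  // reuse: different call sites share the same storage
      }
    }
    auto p = std::make_shared<Param>(Param{key});
    params_.push_back(p);
    return p;
  }

 private:
  // The real pool additionally treats AbstractRef as compatible with AbstractTensor.
  static bool IsCompatible(const std::string &a, const std::string &b) { return a == b; }
  std::vector<ParamPtr> params_;
};
```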
+ // + void Prepare() { + AnfNodePtr last_monad = nullptr; + auto nodes = TopoSort(kernel_graph_->output()); + for (auto &node : nodes) { + MS_EXCEPTION_IF_NULL(node); + if (HasAbstractUMonad(node)) { + // Found a node with UMonad abstract, set it as the last monad. + last_monad = node; + } + auto cnode = node->cast(); + if (cnode == nullptr) { + continue; + } + if (cnode->size() < 1) { + MS_LOG(EXCEPTION) << "Invalid CNode: " << cnode->DebugString() << std::endl; + } + if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall) || + AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) { + // Found call/switch node, set it as the tail call node. + tail_call_node_ = cnode; + call_switch_nodes_.emplace_back(cnode); + monad_map_.emplace(cnode, last_monad); + } else if (tail_call_node_ != nullptr && AnfAlgo::IsRealKernel(cnode)) { + // Clear the tail call if a real kernel cnode is found after the call/switch. + tail_call_node_ = nullptr; + } + } + } + + // + // Handle call and switch nodes. + // + void HandleCallSwitch() { + // Handle call/switch nodes. + for (auto &cnode : call_switch_nodes_) { + if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall)) { + HandleCall(cnode); + } else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) { + HandleSwitch(cnode); + } else { + MS_LOG(EXCEPTION) << "Not a call/switch node: " << cnode->DebugString(); + } + } + // If no tail call, assign output value to output parameter, + // and then goto the return label if set. + if (tail_call_node_ == nullptr) { + if (output_parameter_) { + auto assign_output = AssignAll(output_parameter_, kernel_graph_->output()); + monad_ = UpdateState(GetMonad(), assign_output); + } + if (return_label_ != kNoLabel) { + (void)LabelGoto(return_label_); + } + } + } + + // + // Convert call node: + // out = Call(graph, arg) + // to: + // r = link_args(graph.para, arg, c) + // c = UpdateState(c, r) + // c = LabelGoto(c) : L1 + // + void HandleCall(const CNodePtr &cnode) { + // Update last_monad_. + last_monad_ = monad_map_[cnode]; + + // The callee graph. + auto graph = GetCallGraph(cnode); + MS_EXCEPTION_IF_NULL(graph); + + // Link arguments for the sub-graph. + constexpr size_t call_arg_index = 2; + auto &inputs = cnode->inputs(); + std::vector args(inputs.begin() + call_arg_index, inputs.end()); + auto linked_args = LinkArguments(args, graph); + if (linked_args != nullptr) { + monad_ = UpdateState(GetMonad(), linked_args); + } + + // Goto sub-graph label. + uint32_t graph_label = GetOrCreateGraphLabel(graph); + auto goto_node = LabelGoto(graph_label); + + // Set child graph attribute, so that subsequent steps such + // as 'select kernel' can handle sub-graphs. + SetChildGrapAttr(goto_node, {graph}); + + // Setup return label if this is not a tail call. + const bool is_tail_call = (cnode == tail_call_node_); + const bool need_return = !is_tail_call; + auto [para_pool, output_para, return_label] = MakeReturn(cnode, need_return); + + // Handle sub-graph recursively. + HandleSubGraph(graph, para_pool, output_para, return_label); + } + + // + // Convert switch node: + // branch1 = Partial(graph1, arg) + // branch2 = Partial(graph2, arg) + // out = Switch(cond, branch1, branch2) + // to: + // r = link_args(graph1, arg) + // c = UpdateState(c, r) + // r = link_args(graph2, arg) + // c = UpdateState(c, r) + // c = LabelSwitch(cond, c) : L1, L2 + // c = LabelSet(c) : + // + void HandleSwitch(const CNodePtr &cnode) { + // Update last_monad_.
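Prepare above decides whether the graph ends in a tail call purely from topological order: the last call/switch survives as tail_call_node_ only if no real kernel comes after it. A standalone model of that scan (plain STL; Kind stands in for the primitive and real-kernel checks):

```cpp
#include <cstddef>
#include <vector>

enum class Kind { kCallOrSwitch, kRealKernel, kOther };

// Returns the index of the surviving tail call, or -1 if a real kernel
// executes after the last call/switch (so there is no tail call).
int FindTailCall(const std::vector<Kind> &topo_order) {
  int tail = -1;
  for (std::size_t i = 0; i < topo_order.size(); ++i) {
    if (topo_order[i] == Kind::kCallOrSwitch) {
      tail = static_cast<int>(i);
    } else if (tail != -1 && topo_order[i] == Kind::kRealKernel) {
      tail = -1;  // a real kernel after the call/switch: not a tail call
    }
  }
  return tail;
}
```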
+    last_monad_ = monad_map_[cnode];
+
+    // Get both branches of the switch, true branch first.
+    auto branches = GetSwitchBranches(cnode);
+
+    // Link arguments and generate labels for branches.
+    std::vector graphs;
+    std::vector labels;
+    graphs.reserve(branches.size());
+    labels.reserve(branches.size());
+    for (auto &[graph, args] : branches) {
+      if (graph == nullptr) {
+        MS_LOG(EXCEPTION) << "Invalid switch: " << cnode->DebugString();
+      }
+      auto linked_args = LinkArguments(args, graph);
+      if (linked_args != nullptr) {
+        monad_ = UpdateState(GetMonad(), linked_args);
+      }
+      graphs.push_back(graph);
+      labels.push_back(GetOrCreateGraphLabel(graph));
+    }
+
+    // Since the true/false branches are reversed in the kernel LabelSwitch,
+    // we reverse graphs and labels to make the false branch first.
+    std::reverse(graphs.begin(), graphs.end());
+    std::reverse(labels.begin(), labels.end());
+
+    // Add a LabelSwitch node.
+    auto switch_node = LabelSwitch(cnode->input(1), labels);
+
+    // Set the child graph attribute for the switch node.
+    SetChildGraphAttr(switch_node, graphs);
+
+    // Set up a return label if required.
+    const bool is_tail_call = (cnode == tail_call_node_);
+    const bool need_return = (return_label_ == kNoLabel || !is_tail_call);
+    auto [para_pool, output_para, return_label] = MakeReturn(cnode, need_return);
+
+    // Handle sub-graphs recursively.
+    for (auto &graph : graphs) {
+      HandleSubGraph(graph, para_pool, output_para, return_label);
+    }
+  }
+
+  ParameterPoolPtr GetParameterPool(bool is_last_call) {
+    if (!is_last_call) {
+      // There are multiple calls in this graph, use a new parameter pool
+      // for each of them except the last one.
+      return context_.NewParameterPool();
+    }
+    // For the last call, try to reuse the parameter pool from the caller.
+    if (para_pool_ == nullptr) {
+      para_pool_ = context_.NewParameterPool();
+    }
+    return para_pool_;
+  }
+
+  // Make the return part of a call for the LabelGoto/LabelSwitch node.
+  std::tuple MakeReturn(const CNodePtr &cnode, bool need_return) {
+    // Find a parameter pool for the output parameter.
+    const bool is_last_call = (cnode == call_switch_nodes_.back());
+    auto para_pool = GetParameterPool(is_last_call);
+
+    // Prepare the return label and output parameter.
+    uint32_t return_label = return_label_;
+    auto output_para = para_pool->GetParameter(cnode->abstract());
+    auto output = output_para;
+
+    // Set up a return label if a return is required.
+    if (need_return) {
+      // Set a new label at the return point.
+      return_label = context_.NewLabel();
+      auto label_node = LabelSet(return_label);
+      // Let output depend on the label node; this ensures the
+      // return label is set before the output is used.
+      output = MakeDepend(output, label_node);
+    }
+
+    // Replace the switch node with the output.
+    kernel_graph_->ReplaceNode(NOT_NULL(cnode), NOT_NULL(output));
+    return {para_pool, output_para, return_label};
+  }
+
+  // Handle sub-graphs recursively.
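+  // Each sub-graph converter inherits the caller's parameter pool, output
+  // parameter and return label, so that e.g. both branches of a switch
+  // (illustrative) assign their result to the same output parameter before
+  // jumping back to the shared return label set up by MakeReturn above.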
+  void HandleSubGraph(const KernelGraphPtr &graph, const ParameterPoolPtr &para_pool, const AnfNodePtr &out_para,
+                      uint32_t return_label) {
+    AscendAutoMonadConverter converter(&context_, graph);
+    converter.para_pool_ = para_pool;
+    converter.output_parameter_ = out_para;
+    converter.return_label_ = return_label;
+    converter.Run();
+  }
+
+  KernelGraphPtr GetCallGraph(const CNodePtr &cnode) {
+    auto input_graph = cnode->input(kCallKernelGraphIndex);
+    MS_EXCEPTION_IF_NULL(input_graph);
+    return GetValueNode(input_graph);
+  }
+
+  GraphArgPair GetSwitchBranch(const CNodePtr &cnode, size_t index) {
+    auto partial_cnode = dyn_cast(cnode->input(index));
+    if (partial_cnode == nullptr) {
+      return {nullptr, {}};
+    }
+    auto &inputs = partial_cnode->inputs();
+    if (!IsPrimitive(inputs.at(0), prim::kPrimPartial)) {
+      MS_LOG(EXCEPTION) << "Invalid switch node: " << cnode->DebugString();
+    }
+    auto graph = GetValueNode(inputs.at(1));
+    constexpr size_t arg_index = 2;
+    return {graph, {inputs.begin() + arg_index, inputs.end()}};
+  }
+
+  std::vector GetSwitchBranches(const CNodePtr &cnode) {
+    constexpr size_t true_index = 2;
+    constexpr size_t false_index = 3;
+    // True branch first, then false branch.
+    return {GetSwitchBranch(cnode, true_index), GetSwitchBranch(cnode, false_index)};
+  }
+
+  //
+  // Link actual arguments to the graph's formal arguments.
+  // For multiple arguments:
+  //   r = Call(fg, arg1, arg2, u)
+  // linked arguments:
+  //   r1 = Assign(para1, arg1, c)
+  //   r2 = Assign(para2, arg2, c)
+  //   tuple = MakeTuple(r1, r2, u)
+  //
+  // For a single argument:
+  //   r = Call(fg, arg)
+  // linked arguments:
+  //   r = Assign(para1, arg1, c)
+  //
+  // For no arguments:
+  //   r = Call(fg)
+  // linked arguments are null.
+  //
+  AnfNodePtr LinkArguments(const std::vector &args, const KernelGraphPtr &graph) {
+    auto &paras = graph->inputs();
+    if (args.size() != paras.size()) {
+      MS_LOG(EXCEPTION) << "Wrong arg number! " << graph->ToString() << " " << args.size() << " != " << paras.size();
+    }
+    // If there is no argument, return null.
+    if (args.empty()) {
+      return nullptr;
+    }
+    // Single argument.
+    if (args.size() == 1) {
+      auto &value = args.front();
+      if (HasAbstractMonad(value) || paras.front() == value) {
+        // No assign for a single monad argument, return it as is.
+        return value;
+      }
+      return Assign(paras.front(), value, true);
+    }
+    // Multiple arguments.
+    AnfNodePtrList tuple_inputs;
+    tuple_inputs.reserve(args.size() + 1);
+    tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
+    for (size_t i = 0; i < args.size(); ++i) {
+      auto &value = args.at(i);
+      if (HasAbstractMonad(value)) {
+        // No assign for monad arguments.
+        tuple_inputs.emplace_back(value);
+        continue;
+      }
+      // Assign general arguments.
+      auto &target = paras.at(i);
+      if (target == value) {
+        continue;
+      }
+      tuple_inputs.emplace_back(Assign(target, value, true));
+    }
+    return kernel_graph_->NewCNode(tuple_inputs);
+  }
+
+  // For some cnodes, attributes may be set on the primitive instance,
+  // so we create a new primitive instance for each cnode.
+  AnfNodePtr NewPrimitive(const PrimitivePtr &prim) { return NewValueNode(std::make_shared(prim->name())); }
+
+  AnfNodePtr GetAssignMonad() {
+    if (last_monad_ != nullptr) {
+      return last_monad_;
+    }
+    return GetMonadValue();
+  }
+
+  // Make an assign cnode.
+  CNodePtr Assign(const AnfNodePtr &target, const AnfNodePtr &source, bool is_link = false) {
+    auto monad = GetAssignMonad();
+    auto assign_prim = std::make_shared(prim::kPrimAssign->name());
+    if (is_link) {
+      // Mark that this assign links a real argument to a formal argument.
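+      // (EraseParameter in ExecuteOrderGenerator below keys on this LINK
+      // attribute to find argument-link assigns that can be erased.)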
+      assign_prim->set_attr(LINK, prim::kValueOne);
+    }
+    auto assign = NewValueNode(assign_prim);
+    auto cnode = kernel_graph_->NewCNode({assign, target, source, monad});
+    cnode->set_abstract(target->abstract());
+    return cnode;
+  }
+
+  // AssignAll supports tuple-to-tuple assign.
+  AnfNodePtr AssignAll(const AnfNodePtr &target, const AnfNodePtr &source) {
+    if (!AnfAlgo::CheckPrimitiveType(target, prim::kPrimMakeTuple)) {
+      // Assign single value.
+      return Assign(target, source);
+    }
+    // Assign tuple.
+    std::vector targets = AnfAlgo::GetAllOutput(target, {prim::kPrimTupleGetItem});
+    std::vector sources = AnfAlgo::GetAllOutput(source, {prim::kPrimTupleGetItem});
+    if (targets.size() != sources.size()) {
+      MS_LOG(EXCEPTION) << "Target size " << targets.size() << " != source size " << sources.size();
+    }
+    AnfNodePtrList tuple_inputs;
+    tuple_inputs.reserve(targets.size() + 1);
+    tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
+    for (size_t i = 0; i < targets.size(); ++i) {
+      tuple_inputs.emplace_back(Assign(targets[i], sources[i]));
+    }
+    return kernel_graph_->NewCNode(tuple_inputs);
+  }
+
+  AnfNodePtr UpdateState(const AnfNodePtr &state, const AnfNodePtr &input) {
+    auto update_state = NewValueNode(prim::kPrimUpdateState);
+    auto update_state_cnode = kernel_graph_->NewCNode({update_state, state, input});
+    update_state_cnode->set_abstract(state->abstract());
+    return update_state_cnode;
+  }
+
+  //
+  // Make an entry label for the current graph.
+  // from:
+  //   def sub_graph(x, y):
+  //     return add(x, y)
+  // to:
+  //   def sub_graph(x, y, c):
+  //     c = LabelSet(c) : entry_label
+  //     return add(x, y)
+  //
+  void SetupEntryLabel(uint32_t entry_label) {
+    // Set entry label.
+    auto label_node = LabelSet(entry_label);
+    // Make the start label the first one in execution order.
+    kernel_graph_->set_start_label(label_node);
+  }
+
+  // Make a Depend cnode.
+  CNodePtr MakeDepend(const AnfNodePtr &origin, const AnfNodePtr &input) {
+    auto depend = NewValueNode(prim::kPrimDepend);
+    auto depend_cnode = kernel_graph_->NewCNode({depend, origin, input});
+    depend_cnode->set_abstract(origin->abstract());
+    return depend_cnode;
+  }
+
+  // Let output depend on monad.
+  void MakeMonadDepend() {
+    auto monad = GetMonad();
+    auto origin_output = kernel_graph_->output();
+    MS_EXCEPTION_IF_NULL(origin_output);
+    auto depend_cnode = MakeDepend(origin_output, monad);
+    kernel_graph_->set_output(depend_cnode);
+  }
+
+  // Gets the last monad node; we use a separate UMonad for control flow.
+  AnfNodePtr &GetMonad() {
+    if (monad_ == nullptr) {
+      monad_ = GetMonadValue();
+    }
+    return monad_;
+  }
+
+  // Gets the monad const value node.
+  AnfNodePtr &GetMonadValue() {
+    if (monad_value_ == nullptr) {
+      // We should create the monad value node by the kernel graph,
+      // so that kernel_info is properly set for it.
+      monad_value_ = kernel_graph_->NewValueNode(kUMonad->ToAbstract(), kUMonad);
+    }
+    return monad_value_;
+  }
+
+  // Make a LabelGoto node.
+  CNodePtr LabelGoto(uint32_t label_id) {
+    auto monad = GetMonad();
+    auto label_goto = NewPrimitive(prim::kPrimLabelGoto);
+    auto cnode = kernel_graph_->NewCNode({label_goto, monad});
+    AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue(label_id), cnode);
+    cnode->set_abstract(monad->abstract());
+    kernel_graph_->set_end_goto(cnode);  // make 'goto' the last one in execution order.
+    monad_ = cnode;
+    return cnode;
+  }
+
+  // Make a LabelSet node.
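+  // A LabelSet marks a jump target and threads the monad through, e.g.
+  // (illustrative) c = LabelSet(c) with attribute kAttrLabelIndex = 3.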
+  CNodePtr LabelSet(uint32_t label_id) {
+    auto monad = GetMonad();
+    auto label_set = NewPrimitive(prim::kPrimLabelSet);
+    auto cnode = kernel_graph_->NewCNode({label_set, monad});
+    AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue(label_id), cnode);
+    cnode->set_abstract(monad->abstract());
+    monad_ = cnode;
+    return cnode;
+  }
+
+  // Make a LabelSwitch node.
+  CNodePtr LabelSwitch(const AnfNodePtr &cond, const std::vector &labels) {
+    auto monad = GetMonad();
+    auto label_switch = NewPrimitive(prim::kPrimLabelSwitch);
+    auto cnode = kernel_graph_->NewCNode({label_switch, cond, monad});
+    auto label_list = MakeValue(labels);
+    AnfAlgo::SetNodeAttr(kAttrLabelSwitchList, label_list, cnode);
+    cnode->set_abstract(monad->abstract());
+    monad_ = cnode;
+    return cnode;
+  }
+
+  // Return kNoLabel when the label id attribute is not set for the graph.
+  uint32_t GetGraphLabel(const KernelGraphPtr &kg) {
+    auto value = kg->get_attr(kAttrLabelIndex);
+    if (value == nullptr) {
+      return kNoLabel;
+    }
+    return GetValue(value);
+  }
+
+  // Get or create the entry label for the given graph.
+  uint32_t GetOrCreateGraphLabel(const KernelGraphPtr &kg) {
+    auto label = GetGraphLabel(kg);
+    if (label == kNoLabel) {
+      // Allocate a new label id and save it to the graph.
+      label = context_.NewLabel();
+      kg->set_attr(kAttrLabelIndex, MakeValue(label));
+    }
+    return label;
+  }
+
+  void SetChildGraphAttr(const AnfNodePtr &node, const std::vector &graphs) {
+    AnfAlgo::SetNodeAttr(kAttrChildGraph, MakeValue(graphs), node);
+  }
+
+ private:
+  AscendAutoMonadContext &context_;
+  const KernelGraphPtr kernel_graph_;
+
+  // Tail call node, null if not found.
+  CNodePtr tail_call_node_;
+
+  // Call/Switch nodes.
+  std::vector call_switch_nodes_;
+
+  // Call/Switch node to monad map.
+  std::map monad_map_;
+
+  // The last monad for Call/Switch nodes.
+  AnfNodePtr last_monad_;
+
+  // The current control flow monad.
+  AnfNodePtr monad_;
+
+  // The control flow monad const value node.
+  AnfNodePtr monad_value_;
+
+  // Parameter to store the return value.
+  AnfNodePtr output_parameter_;
+
+  // Parameter pool for output parameter allocation.
+  ParameterPoolPtr para_pool_;
+
+  // The return label id.
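+  // kNoLabel means no explicit goto back to a caller is emitted.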
+  uint32_t return_label_ = kNoLabel;
+};
+
+constexpr size_t kAssignTargetIndex = 1;
+constexpr size_t kAssignSourceIndex = 2;
+
+class ExecuteOrderGenerator {
+ public:
+  class Context : public BaseContext {};
+  ExecuteOrderGenerator(Context &context, const KernelGraphPtr &graph) : context_(context), graph_(graph) {}
+  ~ExecuteOrderGenerator() = default;
+
+  void Run() {
+    GenerateExecuteOrder();
+    EraseParameter();
+    EraseLabel();
+  }
+
+ private:
+  void GenerateGraphOrder(const KernelGraphPtr &graph) {
+    ExecuteOrderGenerator generator(context_, graph);
+    generator.GenerateExecuteOrder();
+  }
+
+  void AppendGraphOrder(std::vector *execution_order, const KernelGraphPtr &graph) {
+    auto &order = graph->execution_order();
+    execution_order->insert(execution_order->end(), order.begin(), order.end());
+  }
+
+  bool HasSubGraphs(const CNodePtr &cnode) { return (cnode && AnfAlgo::HasNodeAttr(kAttrChildGraph, cnode)); }
+
+  std::vector GetSubGraphs(const CNodePtr &cnode) {
+    return AnfAlgo::GetNodeAttr>(cnode, kAttrChildGraph);
+  }
+
+  void EraseNodeFromExecOrder(const AnfNodePtr &node, const NotNull *> exec_order) {
+    MS_EXCEPTION_IF_NULL(node);
+    auto exec_iter = std::find(exec_order->begin(), exec_order->end(), node);
+    if (exec_iter == exec_order->end()) {
+      MS_LOG(EXCEPTION) << "Cannot find " << node->DebugString() << " in exec order.";
+    }
+    exec_order->erase(exec_iter);
+  }
+
+  void GenerateExecuteOrder() {
+    // Mark the graph as visited.
+    context_.MarkVisited(graph_);
+
+    // Generate a topo-sorted kernel cnode list for this graph.
+    graph_->SetExecOrderByDefault();
+
+    std::vector execution_order;
+    const auto &cnodes = graph_->execution_order();
+    for (auto cnode : cnodes) {
+      // Push the current node to the execution order list.
+      execution_order.push_back(cnode);
+      // For cnodes with sub-graphs, such as LabelSwitch and LabelGoto,
+      // generate the execution order for these sub-graphs,
+      // and then append them to the current execution order list.
+      if (HasSubGraphs(cnode)) {
+        // We generate sub-graph execution orders in reversed order,
+        // because the true branch of LabelSwitch is the second one, but
+        // we want the true branch ahead of the false branch in the
+        // generated execution order.
+        auto sub_graphs = GetSubGraphs(cnode);
+        for (auto iter = sub_graphs.rbegin(); iter != sub_graphs.rend(); iter++) {
+          auto &sub_graph = *iter;
+          if (context_.IsVisited(sub_graph)) {
+            // Skip visited sub-graphs.
+            continue;
+          }
+          GenerateGraphOrder(sub_graph);
+          AppendGraphOrder(&execution_order, sub_graph);
+        }
+        // Clear the ChildGraph attribute after the execution order is generated.
+        AnfAlgo::EraseNodeAttr(kAttrChildGraph, cnode);
+      }
+    }
+    // Save the generated execution order into the graph.
+    graph_->set_execution_order(std::move(execution_order));
+  }
+
+  static const AnfNodePtr &GetRealNode(const AnfNodePtr &input) {
+    if (IsPrimitiveCNode(input, prim::kPrimLoad) || IsPrimitiveCNode(input, prim::kPrimDepend)) {
+      return input->cast()->inputs().at(1);
+    }
+    return input;
+  }
+
+  // Erase redundant parameters and assign nodes.
+  void EraseParameter() {
+    // Copy out the execution order list.
+    auto exec_order = graph_->execution_order();
+
+    // Remove assigns whose target and source are the same.
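+    // For example (illustrative), an Assign(%para1, Load(%para1, U), U) created
+    // when a graph passes its own parameter straight through a call is a no-op.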
+    for (auto iter = exec_order.begin(); iter != exec_order.end();) {
+      auto &node = *iter;
+      auto &inputs = node->inputs();
+      if (IsPrimitiveCNode(node, prim::kPrimAssign) &&
+          (inputs.at(kAssignTargetIndex) == GetRealNode(inputs.at(kAssignSourceIndex)))) {
+        iter = exec_order.erase(iter);
+      } else {
+        ++iter;
+      }
+    }
+
+    // Count parameter write times by checking all assign nodes.
+    auto param_write_times = CountParameterAssigns(exec_order);
+
+    // Erase redundant assigns.
+    for (auto iter = exec_order.begin(); iter != exec_order.end();) {
+      auto &node = *iter;
+      // We only try to erase argument-link assign nodes;
+      // other assign nodes are skipped.
+      if (IsLinkAssign(node)) {
+        auto &target = node->inputs().at(kAssignTargetIndex);
+        MS_EXCEPTION_IF_NULL(target);
+        auto para = param_write_times.find(target);
+        if (para != param_write_times.end() && para->second == 1) {
+          // If the target is written only once, replace the target with the
+          // source and erase the assign node.
+          auto &source = node->inputs().at(kAssignSourceIndex);
+          auto kg = target->func_graph()->cast();
+          MS_EXCEPTION_IF_NULL(kg);
+          kg->ReplaceNode(NOT_NULL(target), NOT_NULL(source));
+          iter = exec_order.erase(iter);
+          continue;
+        }
+      }
+      // Go to the next node.
+      ++iter;
+    }
+    // Set the new execution order with redundant assigns removed.
+    graph_->set_execution_order(std::move(exec_order));
+  }
+
+  // Count parameter write times by checking all assign nodes.
+  std::map CountParameterAssigns(const std::vector &all_nodes) {
+    // Find all graph input parameters.
+    std::map param_write_times;
+    const auto &all_graphs = context_.visited_graphs();
+    for (const auto &graph : all_graphs) {
+      for (auto &input : graph->inputs()) {
+        if (input->isa()) {
+          param_write_times.emplace(input, 0);
+        }
+      }
+    }
+    // Search all nodes for parameter write assigns.
+    for (auto &node : all_nodes) {
+      if (!IsPrimitiveCNode(node, prim::kPrimAssign)) {
+        continue;
+      }
+      auto &target = node->inputs().at(kAssignTargetIndex);
+      MS_EXCEPTION_IF_NULL(target);
+      auto iter = param_write_times.find(target);
+      if (iter != param_write_times.end()) {
+        // Found a parameter writer, count it.
+        ++(iter->second);
+      }
+    }
+    return param_write_times;
+  }
+
+  // Check if a node is an assign for an argument link.
+  bool IsLinkAssign(const AnfNodePtr &node) {
+    auto cnode = dyn_cast(node);
+    if (cnode == nullptr) {
+      return false;
+    }
+    auto prim = GetValueNode(cnode->inputs().at(0));
+    if (!IsPrimitiveEquals(prim, prim::kPrimAssign)) {
+      return false;
+    }
+    return prim->GetAttr(LINK) == prim::kValueOne;
+  }
+
+  // Erase LabelGoto and LabelSet nodes.
+  void EraseLabel() {
+    // Find used labels (as jump targets).
+    std::set label_used;
+    auto exec_order = graph_->execution_order();
+    for (auto iter = exec_order.begin(); iter != exec_order.end();) {
+      auto &node = *iter;
+      if (IsPrimitiveCNode(node, prim::kPrimLabelSwitch)) {
+        auto labels = AnfAlgo::GetNodeAttr>(node, kAttrLabelSwitchList);
+        for (auto label : labels) {
+          label_used.insert(label);
+        }
+      } else if (IsPrimitiveCNode(node, prim::kPrimLabelGoto)) {
+        auto label = AnfAlgo::GetNodeAttr(node, kAttrLabelIndex);
+        auto next = std::next(iter);
+        if (next != exec_order.end() && IsPrimitiveCNode(*next, prim::kPrimLabelSet)) {
+          // A LabelGoto that jumps to the immediately following LabelSet can be removed.
+          auto next_label = AnfAlgo::GetNodeAttr(*next, kAttrLabelIndex);
+          if (next_label == label) {
+            iter = exec_order.erase(iter);
+            continue;
+          }
+        }
+        label_used.insert(label);
+      }
+      ++iter;
+    }
+    // Erase unused LabelSet nodes.
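+    // That is, LabelSet nodes whose label id is neither a LabelSwitch target
+    // nor the target of a remaining LabelGoto.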
+    for (auto iter = exec_order.begin(); iter != exec_order.end();) {
+      auto &node = *iter;
+      if (IsPrimitiveCNode(node, prim::kPrimLabelSet)) {
+        auto label = AnfAlgo::GetNodeAttr(node, kAttrLabelIndex);
+        if (label_used.find(label) == label_used.end()) {
+          iter = exec_order.erase(iter);
+          continue;
+        }
+      }
+      ++iter;
+    }
+    graph_->set_execution_order(std::move(exec_order));
+  }
+
+  Context &context_;
+  const KernelGraphPtr graph_;
+};
+
+}  // namespace
+
+void AscendAutoMonad::Run() {
+  MS_LOG(DEBUG) << "Ascend auto-monad start.";
+  AscendAutoMonadContext context(kernel_graph_.get());
+  AscendAutoMonadConverter converter(&context, kernel_graph_.get());
+  converter.Run();
+  kernel_graph_->set_label_num(context.CurrentLabel());
+  MS_LOG(DEBUG) << "Ascend auto-monad finish.";
+  DumpGraphForDebug(kernel_graph_);
+}
+
+void AscendAutoMonad::GenerateExecuteOrder() {
+  MS_LOG(DEBUG) << "Ascend generate execute order start.";
+  ExecuteOrderGenerator::Context context;
+  ExecuteOrderGenerator generator(context, kernel_graph_.get());
+  generator.Run();
+  MS_LOG(DEBUG) << "Ascend generate execute order finish.";
+  DumpExecuteOrder(kernel_graph_);
+}
+}  // namespace session
+}  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/session/ascend_auto_monad.h b/mindspore/ccsrc/backend/session/ascend_auto_monad.h
new file mode 100644
index 0000000000..3238852ce8
--- /dev/null
+++ b/mindspore/ccsrc/backend/session/ascend_auto_monad.h
@@ -0,0 +1,43 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_BACKEND_SESSION_ASCEND_AUTO_MONAD_H
+#define MINDSPORE_CCSRC_BACKEND_SESSION_ASCEND_AUTO_MONAD_H
+
+#include "backend/session/kernel_graph.h"
+
+namespace mindspore {
+namespace session {
+//
+// AscendAutoMonad handles control flow with auto-monad for the Ascend backend.
+//
+class AscendAutoMonad {
+ public:
+  explicit AscendAutoMonad(NotNull kg) : kernel_graph_(kg) {}
+  ~AscendAutoMonad() = default;
+
+  // Handle control flow calls by auto-monad.
+  void Run();
+
+  // Generate the execution order by joining sub-graphs.
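+  // Typical usage (an illustrative sketch; see ascend_session.cc below):
+  //   AscendAutoMonad auto_monad(NOT_NULL(graph));
+  //   auto_monad.Run();                   // at compile time: convert control flow
+  //   auto_monad.GenerateExecuteOrder();  // at validate time: flatten the order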
+ void GenerateExecuteOrder(); + + private: + NotNull kernel_graph_; +}; +} // namespace session +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_BACKEND_SESSION_ASCEND_AUTO_MONAD_H diff --git a/mindspore/ccsrc/backend/session/ascend_control_parser.cc b/mindspore/ccsrc/backend/session/ascend_control_parser.cc index 4a7bc450be..96984a39c6 100644 --- a/mindspore/ccsrc/backend/session/ascend_control_parser.cc +++ b/mindspore/ccsrc/backend/session/ascend_control_parser.cc @@ -796,7 +796,7 @@ std::vector AscendControlParser::RecurseGraph(NotNull if (cnodes.rbegin() != cnodes.rend() && *cnodes.rbegin() == end_label_goto) { cnodes.pop_back(); } - AnfAlgo::ReorderExecList(NOT_NULL(&cnodes)); + AnfAlgo::ReorderOptimizerExecList(NOT_NULL(&cnodes)); if (end_label_goto != nullptr) { cnodes.push_back(end_label_goto); } diff --git a/mindspore/ccsrc/backend/session/ascend_session.cc b/mindspore/ccsrc/backend/session/ascend_session.cc index 700acb3d97..cd2ecd784b 100644 --- a/mindspore/ccsrc/backend/session/ascend_session.cc +++ b/mindspore/ccsrc/backend/session/ascend_session.cc @@ -61,6 +61,7 @@ #include "backend/optimizer/graph_kernel/value_graph_binder.h" #include "backend/optimizer/graph_kernel/add_atomic_clean.h" #include "backend/optimizer/pass/getitem_tuple.h" +#include "backend/session/ascend_auto_monad.h" #include "debug/data_dump/e2e_dump_util.h" #include "debug/anf_ir_dump.h" #include "debug/dump_proto.h" @@ -116,6 +117,12 @@ void DumpGraphExeOrder(const std::vector &execution_order, const std:: buf << "================== execution order ==================\n"; } +// Handle control flow by auto-monad. +void HandleControlFlow(NotNull graph) { + AscendAutoMonad auto_monad(graph); + auto_monad.Run(); +} + void SetStreamDistinctionLabel(const KernelGraphPtr &graph, uint32_t label, bool is_override) { MS_EXCEPTION_IF_NULL(graph); if (is_override || graph->stream_distinction_label() == kInvalidDistincLabel) { @@ -361,16 +368,14 @@ GraphId AscendSession::CompileGraphImpl(NotNull func_graph) { InitRuntimeResource(); return root_graph->graph_id(); } - // create parameter for multiple branch - std::set memo; - CreateMultiBranchOutput(NOT_NULL(root_graph), NOT_NULL(&memo)); - memo.clear(); - // insert goto labels and label_sets - LinkChildGraphs(NOT_NULL(root_graph)); - // replace labelgoto with labelswitch in subgraph called multiple times - MultiCallGraphOptimize(NOT_NULL(root_graph)); + + // Handle control flow by auto-monad. 
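+  // (This single pass replaces the CreateMultiBranchOutput, LinkChildGraphs
+  // and MultiCallGraphOptimize steps removed above.)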
+  HandleControlFlow(NOT_NULL(root_graph));
+
   // resource initialize
   InitRuntimeResource();
+
+  std::set memo;
   IrFusionPass(NOT_NULL(root_graph), NOT_NULL(&memo));
   memo.clear();
   SelectKernel(NOT_NULL(root_graph));
@@ -493,10 +498,6 @@ void AscendSession::BuildGraphImpl(GraphId graph_id) {
 #if ENABLE_CPU && ENABLE_D
   InitPsWorker(graph);
 #endif
-  // Reorder optimizer order
-  auto execution_order = graph->execution_order();
-  Reorder(&execution_order);
-  graph->set_execution_order(execution_order);
   // Assign streams for control sink and hccl and so on
   AssignStream(NOT_NULL(graph));
@@ -1267,7 +1268,8 @@ void AscendSession::SyncDataToExtraParams(NotNull graph, NotNull
 }
 void AscendSession::RootGraphExecutorValidate(NotNull graph) {
-  AscendControlParser::ExecutorValidate(graph);
+  AscendAutoMonad auto_monad(graph);
+  auto_monad.GenerateExecuteOrder();
 }
 void AscendSession::CreateMultiBranchOutput(NotNull graph, NotNull *> memo) {
diff --git a/mindspore/ccsrc/backend/session/cpu_session.cc b/mindspore/ccsrc/backend/session/cpu_session.cc
index f313d9746a..2663022cd9 100644
--- a/mindspore/ccsrc/backend/session/cpu_session.cc
+++ b/mindspore/ccsrc/backend/session/cpu_session.cc
@@ -52,6 +52,9 @@ ParameterPtr CPUSession::CreateNewParameterFromParameter(const AnfNodePtr &anf,
   return new_parameter;
 }
 
+// Remove after the PS feature finishes adapting push/pull in auto_monad.
+void CPUSession::Reorder(std::vector *node_list) { AnfAlgo::ReorderPosteriorExecList(NOT_NULL(node_list)); }
+
 void CPUSession::Optimize(const std::shared_ptr &kernel_graph) {
   auto optimizer = std::make_shared();
   auto pm = std::make_shared();
@@ -81,14 +84,17 @@ GraphId CPUSession::CompileGraphImpl(const AnfNodePtrList &lst, const AnfNodePtr
 #endif
   MS_LOG(INFO) << "Build kernel";
   BuildKernel(graph.get());
-  // Set graph execution order before memory alloc, ensure that memory alloc is according to the reorder graph
+
+  // Remove reorder after the PS feature finishes adapting push/pull in auto_monad.
   auto execution_order = graph->execution_order();
   Reorder(&execution_order);
   graph->set_execution_order(execution_order);
+
   // runtime init
   if (!runtime_.Init()) {
     MS_LOG(EXCEPTION) << "Kernel runtime init error.";
   }
+
   MS_LOG(INFO) << "Assign kernel address";
   runtime_.AssignKernelAddress(graph.get());
   return graph_id;
@@ -186,7 +192,7 @@ void CPUSession::RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info,
   auto kernel_graph = run_op_graphs_[graph_info];
   MS_EXCEPTION_IF_NULL(kernel_graph);
-  // Set graph execution order before memory alloc, ensure that memory alloc is according to the reorder graph
+  // Remove reorder after the PS feature finishes adapting push/pull in auto_monad.
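+  // (CPUSession::Reorder, added above, delegates to
+  // AnfAlgo::ReorderPosteriorExecList for this.)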
auto execution_order = kernel_graph->execution_order(); Reorder(&execution_order); kernel_graph->set_execution_order(execution_order); diff --git a/mindspore/ccsrc/backend/session/cpu_session.h b/mindspore/ccsrc/backend/session/cpu_session.h index e121b84f7a..559d7aaa76 100644 --- a/mindspore/ccsrc/backend/session/cpu_session.h +++ b/mindspore/ccsrc/backend/session/cpu_session.h @@ -46,6 +46,7 @@ class CPUSession : public SessionBasic { VectorRef *outputs, const std::vector &tensors_mask) override; private: + void Reorder(std::vector *node_list); void SetKernelInfo(const KernelGraph *kernel_graph); void BuildKernel(const KernelGraph *kernel_graph); void SetOutputFlags(const VectorRef &base_ref, std::vector *outputs_tensors); diff --git a/mindspore/ccsrc/backend/session/gpu_session.cc b/mindspore/ccsrc/backend/session/gpu_session.cc index 29dd67e94a..8564de2ec5 100644 --- a/mindspore/ccsrc/backend/session/gpu_session.cc +++ b/mindspore/ccsrc/backend/session/gpu_session.cc @@ -28,6 +28,7 @@ #include "backend/optimizer/gpu/batch_norm_relu_fusion.h" #include "backend/optimizer/gpu/batch_norm_relu_grad_fusion.h" #include "backend/optimizer/gpu/batch_norm_add_relu_fusion.h" +#include "backend/optimizer/gpu/post_batch_norm_add_relu_fusion.h" #include "backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.h" #include "backend/optimizer/gpu/combine_momentum_fusion.h" #include "backend/optimizer/gpu/combine_cast_fusion.h" @@ -57,6 +58,7 @@ #include "backend/optimizer/graph_kernel/value_graph_binder.h" #include "backend/optimizer/graph_kernel/parallel_fusion.h" #include "backend/optimizer/graph_kernel/optimize_assign.h" +#include "backend/optimizer/graph_kernel/split_assign.h" #include "backend/optimizer/pass/communication_op_fusion.h" #include "backend/optimizer/pass/getitem_tuple.h" #include "common/trans.h" @@ -151,6 +153,7 @@ void GPUSession::HardwareOptimize(const std::shared_ptr &kernel_gra pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared()); + pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared()); @@ -185,6 +188,7 @@ void GPUSession::GraphKernelOptimize(const std::shared_ptr &kernel_ auto optimizer = std::make_shared(); auto pm = std::make_shared("graph_kernel_pm"); std::vector duplicated_ops = {prim::kPrimReshape, prim::kPrimExpandDims, prim::kPrimCast}; + pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared()); // Make more fusion opportunity. 
pm->AddPass(std::make_shared()); pm->AddPass(std::make_shared()); @@ -245,6 +249,30 @@ void GPUSession::RunOpClearMemory(KernelGraph *kernel_graph) const { runtime_instance->RunOpClearMemory(kernel_graph); } +namespace { +constexpr auto kAssignInputSize = 3; +constexpr auto kAssignUpdateIndex = 1; +bool UpdatedByAssign(const KernelGraphPtr &kernel_graph, const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(kernel_graph); + auto manager = kernel_graph->manager(); + if (manager == nullptr) { + return false; + } + auto &node_users = manager->node_users(); + auto iter = node_users.find(node); + if (iter == node_users.end()) { + return false; + } + auto &users = iter->second; + return std::any_of(users.begin(), users.end(), [](const std::pair &user) { + MS_EXCEPTION_IF_NULL(user.first); + auto output_cnode = user.first->cast(); + return output_cnode != nullptr && IsPrimitiveCNode(output_cnode, prim::kPrimAssign) && + user.second == kAssignUpdateIndex && output_cnode->inputs().size() > kAssignInputSize; + }); +} +} // namespace + void GPUSession::LoadInputData(const std::shared_ptr &kernel_graph, const std::vector &inputs_const) const { std::vector inputs(inputs_const); @@ -285,7 +313,7 @@ void GPUSession::LoadInputData(const std::shared_ptr &kernel_graph, } } if (need_sync) { - if (AnfAlgo::IsParameterWeight(input_node->cast()) || + if (AnfAlgo::IsParameterWeight(input_node->cast()) || UpdatedByAssign(kernel_graph, input_node) || ms_context->get_param(MS_CTX_EXECUTION_MODE) == kPynativeMode) { tensor->set_device_address(device_address); } @@ -365,10 +393,6 @@ GraphId GPUSession::CompileGraphImpl(KernelGraphPtr graph) { } // Build kernel if node is cnode BuildKernel(graph); - // Set graph execution order before memory alloc, ensure that memory alloc is according to the reorder graph - auto execution_order = graph->execution_order(); - Reorder(&execution_order); - graph->set_execution_order(execution_order); // Get summary nodes. 
   SetSummaryNodes(graph.get());
   // Dump .pb graph after graph optimization
diff --git a/mindspore/ccsrc/backend/session/kernel_graph.cc b/mindspore/ccsrc/backend/session/kernel_graph.cc
index d0d1f1f331..c05ebc33b2 100644
--- a/mindspore/ccsrc/backend/session/kernel_graph.cc
+++ b/mindspore/ccsrc/backend/session/kernel_graph.cc
@@ -383,7 +383,8 @@ void ReSetParameterValueNodeFormatAndType(const AnfNodePtr &node, const std::str
 void KernelGraph::ResetInFormat(const AnfNodePtr &node, const std::string &format) const {
   MS_EXCEPTION_IF_NULL(node);
-  for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(node); i++) {
+  size_t input_num = AnfAlgo::GetInputTensorNum(node);
+  for (size_t i = 0; i < input_num; i++) {
     auto in_node = AnfAlgo::GetInputNode(node->cast(), i);
     MS_EXCEPTION_IF_NULL(in_node);
     if ((in_node->isa() || in_node->isa()) &&
@@ -438,7 +439,8 @@ void KernelGraph::CreateKernelInfoFromNewParameter(const CNodePtr &cnode) {
   }
   auto anf_cnode = anf_node->cast();
   MS_EXCEPTION_IF_NULL(anf_cnode);
-  for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(anf_cnode); ++i) {
+  size_t input_num = AnfAlgo::GetInputTensorNum(anf_cnode);
+  for (size_t i = 0; i < input_num; ++i) {
     auto input_node = anf_cnode->input(i + 1);
     MS_EXCEPTION_IF_NULL(input_node);
     if (IsValueNode(input_node)) {
@@ -488,7 +490,8 @@ void KernelGraph::SetKernelInfoForNode(const AnfNodePtr &node) const {
   std::vector feature_map_input_indexs;
 #endif
   kernel_info->set_feature_map_flag(false);
-  for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(node); ++index) {
+  size_t input_num = AnfAlgo::GetInputTensorNum(node);
+  for (size_t index = 0; index < input_num; ++index) {
     if (AnfAlgo::IsFeatureMapInput(node, index)) {
       kernel_info->set_feature_map_flag(true);
       feature_map_input_indexs.push_back(index);
@@ -647,7 +650,8 @@ AnfNodePtr KernelGraph::TransCNodeTuple(const CNodePtr &node) {
   std::vector types;
   std::vector> shapes;
   std::vector make_tuple_inputs_list = {mindspore::NewValueNode(prim::kPrimMakeTuple)};
-  for (size_t tuple_out_index = 0; tuple_out_index < AnfAlgo::GetOutputTensorNum(node); ++tuple_out_index) {
+  size_t output_num = AnfAlgo::GetOutputTensorNum(node);
+  for (size_t tuple_out_index = 0; tuple_out_index < output_num; ++tuple_out_index) {
     make_tuple_inputs_list.emplace_back(CreatTupleGetItemNode(node, tuple_out_index));
     types.push_back(AnfAlgo::GetOutputInferDataType(node, tuple_out_index));
     shapes.emplace_back(AnfAlgo::GetOutputInferShape(node, tuple_out_index));
@@ -886,7 +890,6 @@ void KernelGraph::UpdateNodeEdgeList(std::queue *seed_nodes) {
   node_output_edges_.clear();
   node_input_num_.clear();
   node_input_edges_.clear();
-  std::vector control_depends;
   std::unordered_set visited_nodes;
   std::queue que;
   que.push(get_return());
@@ -898,24 +901,18 @@
       seed_nodes->push(node);
       continue;
     }
-    if (!node->isa()) {
+    auto cnode = dyn_cast(node);
+    if (cnode == nullptr) {
       continue;
    }
-    auto cnode = node->cast();
-    MS_EXCEPTION_IF_NULL(cnode);
-    // handle data links
-    for (const auto &input : cnode->inputs()) {
-      size_t depend_edge_num = 1;
-      // handle control depend,all inputs of control depend has no depend edge
-      if (HandleControlDependNode(input, &que, &visited_nodes)) {
-        control_depends.push_back(input);
-        depend_edge_num = 0;
-      }
+    auto &inputs = cnode->inputs();
+    // We push inputs from right to left, so that they can be evaluated from left to right.
+    for (auto iter = inputs.rbegin(); iter != inputs.rend(); ++iter) {
+      auto &input = *iter;
       PushNoVisitedNode(input, &que, &visited_nodes);
-      AddDependEdge(node, input, depend_edge_num);
+      AddDependEdge(node, input, 1);
     }
   }
-  UpdateControlDependRelations(control_depends);
 }
 
 void KernelGraph::AddValueNodeToGraph(const ValueNodePtr &value_node) { (void)graph_value_nodes_.insert(value_node); }
@@ -984,7 +981,6 @@ void KernelGraph::ReplaceNode(NotNull old_anf_node, NotNull
       old_anf_node->set_input(i, new_anf_node);
     }
   }
-  ReplaceGraphInput(old_anf_node, new_anf_node);
   // update front to backend map
   FrontBackendlMapUpdate(old_anf_node, new_anf_node);
 }
@@ -1041,6 +1037,9 @@ std::vector KernelGraph::FindNodeByPrimitive(const std::vector
   void set_execution_order(const std::vector &order) { execution_order_ = order; }
+  void set_execution_order(std::vector &&order) { execution_order_ = std::move(order); }
   const std::vector &execution_order() const { return execution_order_; }
   void SetExecOrderByDefault();
   uint32_t graph_id() const { return graph_id_; }
@@ -273,6 +274,9 @@ class KernelGraph : public FuncGraph {
   }
   // end of handle graph dependency
+  uint32_t label_num() const { return label_num_; }
+  void set_label_num(uint32_t num) { label_num_ = num; }
+
  private:
   // remove value node form graph
   bool RemoveValueNodeFromGraph(const ValueNodePtr &value_node);
@@ -358,6 +362,10 @@
   bool first_step_{true};
   bool has_optimizer_{false};
   bool is_dynamic_shape_{false};
+
+  // Number of labels. This is also the 'batch_num' for DavinciModel;
+  // it should be 1 if no labels are used for control flow.
+  uint32_t label_num_ = 1;
 };
 }  // namespace session
 using KernelGraphPtr = std::shared_ptr;
diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc
index c457de1dcd..d4f8498a21 100644
--- a/mindspore/ccsrc/backend/session/session_basic.cc
+++ b/mindspore/ccsrc/backend/session/session_basic.cc
@@ -164,6 +164,12 @@ tensor::TensorPtr CreateCNodeOutputTensor(const session::KernelWithIndex &node_o
   return tensor;
 }
 
+static bool IsPynativeMode() {
+  auto ms_context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(ms_context);
+  return ms_context->get_param(MS_CTX_EXECUTION_MODE) == kPynativeMode;
+}
+
 BaseRef CreateNodeOutputTensor(const session::KernelWithIndex &node_output_pair, const KernelGraphPtr &graph,
                                const std::vector &input_tensors,
                                std::map *tensor_to_node) {
@@ -172,17 +178,18 @@ BaseRef CreateNodeOutputTensor(const session::KernelWithIndex &node_output_pair,
   MS_EXCEPTION_IF_NULL(node);
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(tensor_to_node);
-  auto ms_context = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(ms_context);
   MS_LOG(INFO) << "Create tensor for output[" << node->DebugString() << "] index[" << node_output_pair.second << "]";
+  if (HasAbstractMonad(node)) {
+    return std::make_shared(int64_t(0), kBool);
+  }
   // if node is a value node, no need sync addr from device to host
   if (node->isa()) {
     auto value_node = node->cast();
     MS_EXCEPTION_IF_NULL(value_node);
     return value_node->value();
   }
-  if (!AnfAlgo::OutputAddrExist(node, output_index) ||
-      (CheckIfNeedCreateOutputTensor(node) && ms_context->get_param(MS_CTX_EXECUTION_MODE) != kPynativeMode)) {
+  bool output_addr_exist = AnfAlgo::OutputAddrExist(node, output_index);
+  if (!output_addr_exist || (CheckIfNeedCreateOutputTensor(node) && !IsPynativeMode())) {
    if (node->isa()) {
      for (size_t input_idx = 0; input_idx < graph->inputs().size(); input_idx++) {
        if (input_idx >= input_tensors.size()) {
@@ -192,7
+199,9 @@ BaseRef CreateNodeOutputTensor(const session::KernelWithIndex &node_output_pair, return input_tensors[input_idx]; } } - MS_LOG(EXCEPTION) << "Parameter : " << node->DebugString() << " has no output addr"; + if (!output_addr_exist) { + MS_LOG(EXCEPTION) << "Parameter : " << node->DebugString() << " has no output addr"; + } } } auto tensor = CreateCNodeOutputTensor(node_output_pair, graph); @@ -688,7 +697,7 @@ std::vector SessionBasic::CreateParameterFromTuple(const AnfNodePtr } // If a cnode is a call, it's input0 is a cnode too, so it doesn't have primitive if (!pre_graph_out.empty() && !AnfAlgo::IsRealKernel(node)) { - pre_graph_out = AnfAlgo::GetAllOutput(node, {prim::kPrimTupleGetItem}); + pre_graph_out = AnfAlgo::GetAllOutput(node, {prim::kPrimTupleGetItem, prim::kPrimUpdateState}); } auto valid_inputs = graph->MutableValidInputs(); MS_EXCEPTION_IF_NULL(valid_inputs); @@ -796,7 +805,6 @@ void SessionBasic::GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, MS_EXCEPTION_IF_NULL(cnode_inputs); auto origin_inputs = cnode->inputs(); bool optimize_depend = IsPrimitiveCNode(cnode, prim::kPrimDepend) && origin_inputs.size() >= 3; - bool optimize_control_depend = IsPrimitiveCNode(cnode, prim::kPrimControlDepend) && origin_inputs.size() == 3; // if has multiple depends,only select first depend as parameter for (size_t input_idx = 1; input_idx < origin_inputs.size(); input_idx++) { auto anf = origin_inputs[input_idx]; @@ -827,14 +835,17 @@ void SessionBasic::GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, (*other_graph_cnode)[anf] = new_parameter; } continue; - } else if (optimize_control_depend || IsPrimitiveCNode(anf, prim::kPrimControlDepend)) { - cnode_inputs->push_back(NewValueNode(MakeValue(SizeToLong(input_idx)))); } else { // the input node is a cnode from other graph auto parameter_from_cnode = CreateNewParameterFromCNode(anf, graph); if (parameter_from_cnode == nullptr) { parameter_from_cnode = NewValueNode(MakeValue(SizeToLong(input_idx))); } + if (parameter_from_cnode->isa() && IsPrimitiveCNode(anf, prim::kPrimLoad)) { + auto para = parameter_from_cnode->cast(); + auto load_cnode = anf->cast(); + para->set_name(load_cnode->input(kFirstDataInputIndex)->fullname_with_scope()); + } cnode_inputs->push_back(parameter_from_cnode); (*other_graph_cnode)[anf] = parameter_from_cnode; } @@ -904,7 +915,10 @@ std::vector SessionBasic::CreateCallSwitchInputs(const CNodePtr &cno auto partial_node = node->cast(); MS_EXCEPTION_IF_NULL(partial_node); std::vector partial_inputs = partial_node->inputs(); - partial_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(cnode->input(kFirstDataInputIndex))); + // Put all call args at the end of partial inputs. 
+ for (size_t i = kFirstDataInputIndex; i < cnode->size(); ++i) { + partial_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(cnode->input(i))); + } auto new_partial = graph->NewCNode(partial_inputs); switch_inputs.emplace_back(new_partial); } @@ -1213,6 +1227,9 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con MS_EXCEPTION_IF_NULL(new_cnode); new_cnode->set_abstract(cnode->abstract()); new_cnode->set_scope(cnode->scope()); + if (IsPrimitiveCNode(cnode, prim::kPrimLoad)) { + new_cnode->set_fullname_with_scope(cnode->input(kFirstDataInputIndex)->fullname_with_scope()); + } // record map relations between anf from ME and new anf node used in backend graph->FrontBackendlMapAdd(node, new_cnode); } @@ -1240,11 +1257,11 @@ KernelGraphPtr SessionBasic::ConstructKernelGraph(const AnfNodePtrList &lst, con auto node_ptr = input_node->cast(); MS_EXCEPTION_IF_NULL(node_ptr); if (!IsUsedByRealKernel(manager, input_node)) { - node_ptr->set_used_by_real_kernel(); + node_ptr->set_used_by_real_kernel(false); } auto shape = node_ptr->Shape(); if (IsShapeDynamic(shape->cast())) { - node_ptr->set_used_by_dynamic_kernel(); + node_ptr->set_used_by_dynamic_kernel(true); } } } @@ -1376,6 +1393,9 @@ void SessionBasic::GetOpInputTensors(const CNodePtr &cnode, MS_EXCEPTION_IF_NULL(real_input); tensor::TensorPtr tensor = nullptr; if (real_input->isa()) { + if (HasAbstractMonad(real_input)) { + continue; + } tensor = GetValueNodeOutputTensor(real_input, kernel_with_index.second); } else if (real_input->isa()) { tensor = GetParameterOutputTensor(real_input, parameter_index, graph_inputs); @@ -1408,6 +1428,8 @@ bool SessionBasic::CreateCNodeOfKernelGraph(const AnfNodePtr &node, KernelGraph std::string fullname; if (cnode->input(kAnfPrimitiveIndex)->isa()) { fullname = cnode->input(kAnfPrimitiveIndex)->fullname_with_scope(); + } else if (IsPrimitiveCNode(cnode, prim::kPrimLoad)) { + fullname = cnode->input(kFirstDataInputIndex)->fullname_with_scope(); } else { fullname = cnode->fullname_with_scope(); } @@ -1483,11 +1505,11 @@ std::shared_ptr SessionBasic::ConstructKernelGraph(const FuncGraphP auto node_ptr = input_node->cast(); MS_EXCEPTION_IF_NULL(node_ptr); if (!IsUsedByRealKernel(manager, input_node)) { - node_ptr->set_used_by_real_kernel(); + node_ptr->set_used_by_real_kernel(false); } auto shape = node_ptr->Shape(); if (IsShapeDynamic(shape->cast())) { - node_ptr->set_used_by_dynamic_kernel(); + node_ptr->set_used_by_dynamic_kernel(true); } } } @@ -1733,8 +1755,6 @@ void SessionBasic::RegisterSummaryCallBackFunc(const CallBackFunc &callback) { summary_callback_ = callback; } -void SessionBasic::Reorder(std::vector *node_list) { AnfAlgo::ReorderExecList(NOT_NULL(node_list)); } - void SessionBasic::RunInfer(NotNull func_graph, const std::vector &inputs) { auto node_list = TopoSort(func_graph->get_return()); size_t tensor_index = 0; @@ -1742,7 +1762,8 @@ void SessionBasic::RunInfer(NotNull func_graph, const std::vector< MS_EXCEPTION_IF_NULL(node); if (node->isa()) { AbstractBasePtrList input_abstracts; - for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(node); ++index) { + size_t input_num = AnfAlgo::GetInputTensorNum(node); + for (size_t index = 0; index < input_num; ++index) { auto input_node = AnfAlgo::GetInputNode(node->cast(), index); MS_EXCEPTION_IF_NULL(input_node); auto abstract = input_node->abstract(); diff --git a/mindspore/ccsrc/backend/session/session_basic.h b/mindspore/ccsrc/backend/session/session_basic.h index 4f16acd2ae..96b274b1d2 100644 --- 
a/mindspore/ccsrc/backend/session/session_basic.h
+++ b/mindspore/ccsrc/backend/session/session_basic.h
@@ -190,7 +190,6 @@ class SessionBasic : public std::enable_shared_from_this {
   void UpdateOutputs(const std::shared_ptr &kernel_graph, VectorRef *const outputs,
                      const std::vector &input_tensors) const;
   void UpdateOutputAbstract(const std::shared_ptr &kernel_graph, OpRunInfo *op_run_info) const;
-  void Reorder(std::vector *node_list);
   void Summary(KernelGraph *graph);
   // create graph output for RunOp
   void CreateOutputNode(const CNodePtr &cnode, const std::shared_ptr &graph);
diff --git a/mindspore/ccsrc/debug/anf_ir_dump.cc b/mindspore/ccsrc/debug/anf_ir_dump.cc
index a8a09931f8..4473c6da88 100644
--- a/mindspore/ccsrc/debug/anf_ir_dump.cc
+++ b/mindspore/ccsrc/debug/anf_ir_dump.cc
@@ -81,7 +81,7 @@ void PrintNodeInputType(std::ostringstream &buffer, const AnfNodePtr &nd) {
     return;
   }
-  std::vector inputs = SuccIncoming(nd);
+  const auto &inputs = GetInputs(nd);
   size_t len = inputs.size();
   if (len > 1) {
     // skip inputs[0] which is Primitive value node
@@ -137,7 +138,8 @@ void DumpKernelInfo(const CNodePtr &node, const std::shared_ptr
   }
   gsub->buffer << " : (";
-  for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(node); ++i) {
+  size_t input_num = AnfAlgo::GetInputTensorNum(node);
+  for (size_t i = 0; i < input_num; ++i) {
     if (i != 0) {
       gsub->buffer << ", ";
     }
@@ -147,7 +148,8 @@
     PrintKernelFormatAndType(gsub->buffer, format, type, shape);
   }
   gsub->buffer << ") -> (";
-  for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(node); ++i) {
+  size_t output_num = AnfAlgo::GetOutputTensorNum(node);
+  for (size_t i = 0; i < output_num; ++i) {
     if (i != 0) {
       gsub->buffer << ", ";
     }
@@ -238,7 +240,7 @@ void DumpOperands(const AnfNodePtr &nd, OrderedMap *para_ma
   }
   gsub->buffer << "(";
-  std::vector inputs = SuccIncoming(nd);
+  const auto &inputs = GetInputs(nd);
   size_t len = inputs.size();
   if (len > 1) {
     // skip inputs[0] which is Primitive valuenode
diff --git a/mindspore/ccsrc/debug/anf_ir_utils.cc b/mindspore/ccsrc/debug/anf_ir_utils.cc
index d2f451b773..e131c605d9 100644
--- a/mindspore/ccsrc/debug/anf_ir_utils.cc
+++ b/mindspore/ccsrc/debug/anf_ir_utils.cc
@@ -21,6 +21,7 @@
 #include
 #include
 #include
+#include
 #include "ir/graph_utils.h"
 #include "utils/symbolic.h"
 #include "ir/meta_func_graph.h"
@@ -447,7 +448,8 @@ std::string AnfExporter::GetAnfNodeText(const FuncGraphPtr &func_graph, const An
     }
     oss << "%" << iter->second;
   } else if (node->isa()) {
-    oss << "%para" << GetParamIndex(func_graph, node, check_integrity_);
+    // Parameter may be a free variable, so check it in its own funcgraph.
+ oss << "%para" << GetParamIndex(node->func_graph(), node, check_integrity_); } else if (IsValueNode(node)) { FuncGraphPtr fg = GetValueNode(node); oss << fg->type_name() << "::fg_" << fg->debug_info()->get_id(); @@ -594,11 +596,43 @@ void AnfExporter::OutputCNodes(std::ofstream &ofs, const std::vector ofs << trace::GetDebugInfo(cnode->debug_info(), " # ", kSourceLineTipDiscard) << "#" << label_manage::Label(cnode->debug_info()) << "\n"; } else { - ofs << trace::GetDebugInfo(cnode->debug_info(), " # ", kSourceLineTipDiscard) << "\n"; + ofs << trace::GetDebugInfo(cnode->debug_info(), " # ", kSourceLineTipDiscard) << "#" << cnode->ToString() + << "\n"; } } } +void AnfExporter::OutputOrderList(std::ofstream &ofs, const FuncGraphPtr &func_graph) { + auto &order_list = func_graph->order_list(); + if (order_list.empty()) { + return; + } + constexpr int width = 4; + ofs << "# order:\n"; + int i = 1; + auto &isolate_nodes = func_graph->isolate_nodes(); + for (auto &node : order_list) { + bool is_isolate = (isolate_nodes.find(node) != isolate_nodes.end()); + const std::string isolate_str = (is_isolate ? " # isolate" : ""); + ofs << '#' << std::setw(width) << i << ": " << node->DebugString() << isolate_str << '\n'; + ++i; + } +} + +void AnfExporter::OutputIsolateNodes(std::ofstream &ofs, const FuncGraphPtr &func_graph) { + auto &isolate_nodes = func_graph->isolate_nodes(); + if (isolate_nodes.empty()) { + return; + } + constexpr int width = 4; + ofs << "# isolate nodes:\n"; + int i = 1; + for (auto &node : isolate_nodes) { + ofs << '#' << std::setw(width) << i << ": " << node->DebugString() << '\n'; + ++i; + } +} + void AnfExporter::ExportOneFuncGraph(std::ofstream &ofs, const FuncGraphPtr &func_graph) { if (func_graph == nullptr) { return; @@ -634,6 +668,9 @@ void AnfExporter::ExportOneFuncGraph(std::ofstream &ofs, const FuncGraphPtr &fun OutputCNodes(ofs, nodes, func_graph); ofs << "}\n"; + + OutputOrderList(ofs, func_graph); + OutputIsolateNodes(ofs, func_graph); } void AnfExporter::ExportFuncGraph(const std::string &filename, const FuncGraphPtr &func_graph) { diff --git a/mindspore/ccsrc/debug/anf_ir_utils.h b/mindspore/ccsrc/debug/anf_ir_utils.h index 359fdef57b..afce09efc6 100644 --- a/mindspore/ccsrc/debug/anf_ir_utils.h +++ b/mindspore/ccsrc/debug/anf_ir_utils.h @@ -97,6 +97,8 @@ class AnfExporter { void OutputStatementComment(std::ofstream &ofs, const CNodePtr &node); virtual void OutputCNodes(std::ofstream &ofs, const std::vector &nodes, const FuncGraphPtr &func_graph); + void OutputOrderList(std::ofstream &ofs, const FuncGraphPtr &func_graph); + void OutputIsolateNodes(std::ofstream &ofs, const FuncGraphPtr &func_graph); int param_index; OrderedSet func_graph_set{}; diff --git a/mindspore/ccsrc/debug/data_dump/e2e_dump_util.cc b/mindspore/ccsrc/debug/data_dump/e2e_dump_util.cc index 350ea455e0..86384ba417 100644 --- a/mindspore/ccsrc/debug/data_dump/e2e_dump_util.cc +++ b/mindspore/ccsrc/debug/data_dump/e2e_dump_util.cc @@ -119,6 +119,9 @@ void E2eDumpUtil::DumpOutputImpl(const CNodePtr &node, bool trans_flag, const st GetFileKernelName(NOT_NULL(kernel_name)); auto output_size = AnfAlgo::GetOutputTensorNum(node); for (size_t j = 0; j < output_size; ++j) { + if (!AnfAlgo::OutputAddrExist(node, j)) { + continue; + } auto addr = AnfAlgo::GetOutputAddr(node, j); ShapeVector int_shapes; GetDumpIntShape(node, j, trans_flag, NOT_NULL(&int_shapes)); @@ -163,6 +166,9 @@ void E2eDumpUtil::DumpInputImpl(const CNodePtr &node, bool trans_flag, const std auto kernel_with_index = 
AnfAlgo::GetPrevNodeOutput(node, j); auto input = kernel_with_index.first; auto index = kernel_with_index.second; + if (!AnfAlgo::OutputAddrExist(input, index)) { + continue; + } auto addr = AnfAlgo::GetOutputAddr(input, index); std::string tensor_name; diff --git a/mindspore/ccsrc/debug/debugger/debug_graph.proto b/mindspore/ccsrc/debug/debugger/debug_graph.proto index 9b9a496367..d268c4fd97 100644 --- a/mindspore/ccsrc/debug/debugger/debug_graph.proto +++ b/mindspore/ccsrc/debug/debugger/debug_graph.proto @@ -88,6 +88,10 @@ enum DataType { DT_ANYTHING = 40; // type anything DT_REFKEY = 41; // type refkey DT_REF = 42; // type ref + + // auto_monad type + DT_UMONAD = 43; + DT_IOMONAD = 44; } // Value definition for attribute value or parameter default value @@ -255,7 +259,7 @@ message ModelProto { // The parameterized graph that is evaluated to execute the model. optional GraphProto graph = 4; - // metadata info of opeartors + // metadata info of operators optional OperatorSetProto metadata_operators = 5; }; diff --git a/mindspore/ccsrc/debug/debugger/proto_exporter.cc b/mindspore/ccsrc/debug/debugger/proto_exporter.cc index 7d2c07dc5a..7bf35b3b0c 100644 --- a/mindspore/ccsrc/debug/debugger/proto_exporter.cc +++ b/mindspore/ccsrc/debug/debugger/proto_exporter.cc @@ -113,6 +113,10 @@ void DebuggerProtoExporter::SetNodeOutputType(const TypePtr &type, const BaseSha type_proto->set_data_type(debugger::DT_STRING); } else if (type->isa()) { // Do Nothing. + } else if (type->isa()) { + type_proto->set_data_type(debugger::DT_UMONAD); + } else if (type->isa()) { + type_proto->set_data_type(debugger::DT_IOMONAD); } else { MS_LOG(EXCEPTION) << "Unknown type: " << type->type_name(); } diff --git a/mindspore/ccsrc/debug/draw.cc b/mindspore/ccsrc/debug/draw.cc index 23deb1d5c7..89bf23a1ad 100644 --- a/mindspore/ccsrc/debug/draw.cc +++ b/mindspore/ccsrc/debug/draw.cc @@ -97,7 +97,7 @@ void DrawValueNodes(const std::vector &nodes, int dup_idx = 0; for (auto &nd : nodes) { - for (auto &t : SuccIncoming(nd)) { + for (auto &t : GetInputs(nd)) { MS_EXCEPTION_IF_NULL(t); MS_EXCEPTION_IF_NULL(nd); if (t->isa() && (*sub_graphs).find(nd->func_graph()) != (*sub_graphs).end()) { @@ -125,7 +125,7 @@ void DrawEdges(const std::vector &nodes, const std::shared_ptrset_data_type(irpb::DT_STRING); } else if (type->isa()) { // Do Nothing. + } else if (type->isa()) { + // Do Nothing. 
+ } else if (type->isa()) { + MS_LOG(WARNING) << "The type: " << type->type_name(); } else { MS_LOG(EXCEPTION) << "Unknown type: " << type->type_name(); } @@ -218,6 +222,8 @@ void ProtoExporter::SetValueToProto(const ValuePtr &val, irpb::ValueProto *value type_proto->set_data_type(irpb::DT_TENSOR); TypePtr elem_type = dyn_cast(val)->element(); type_proto->mutable_tensor_type()->set_elem_type(GetNumberDataType(elem_type)); + } else if (val->isa()) { + value_proto->set_str_val(val->ToString()); } else { MS_LOG(WARNING) << "Unsupported type " << val->type_name(); } diff --git a/mindspore/ccsrc/frontend/operator/composite/composite.cc b/mindspore/ccsrc/frontend/operator/composite/composite.cc index 61b4bfbda1..8aabad979b 100644 --- a/mindspore/ccsrc/frontend/operator/composite/composite.cc +++ b/mindspore/ccsrc/frontend/operator/composite/composite.cc @@ -146,7 +146,7 @@ AnfNodePtr HyperMap::FullMake(TypePtr, const FuncGraphPtr &func_graph, const Anf (void)std::transform(arg_map.begin(), arg_map.end(), std::back_inserter(inputs), [](const std::pair &item) { return item.first; }); - return func_graph->NewCNode(inputs); + return func_graph->NewCNodeInOrder(inputs); } AnfNodePtr HyperMap::FullMake(const std::shared_ptr &type, const FuncGraphPtr &func_graph, @@ -180,12 +180,12 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr &type, const FuncGraph (void)std::transform( arg_map.begin(), arg_map.end(), std::back_inserter(inputs2), [&func_graph, i](const std::pair &item) { - return func_graph->NewCNode({NewValueNode(prim::kPrimListGetItem), item.first, NewValueNode(i)}); + return func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimListGetItem), item.first, NewValueNode(i)}); }); - inputs.push_back(func_graph->NewCNode(inputs2)); + inputs.push_back(func_graph->NewCNodeInOrder(inputs2)); } - return func_graph->NewCNode(inputs); + return func_graph->NewCNodeInOrder(inputs); } AnfNodePtr HyperMap::FullMake(const std::shared_ptr &type, const FuncGraphPtr &func_graph, @@ -218,12 +218,12 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr &type, const FuncGrap (void)std::transform( arg_map.begin(), arg_map.end(), std::back_inserter(inputs2), [&func_graph, &i](std::pair item) { - return func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), item.first, NewValueNode(i)}); + return func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), item.first, NewValueNode(i)}); }); - inputs.push_back(func_graph->NewCNode(inputs2)); + inputs.push_back(func_graph->NewCNodeInOrder(inputs2)); } - return func_graph->NewCNode(inputs); + return func_graph->NewCNodeInOrder(inputs); } AnfNodePtr HyperMap::FullMake(const std::shared_ptr &type, const FuncGraphPtr &func_graph, @@ -248,13 +248,13 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr &type, const FuncGrap int64_t j = 0; for (auto item : arg_map) { - inputs2.push_back(func_graph->NewCNode({NewValueNode(prim::kPrimGetAttr), item.first, NewValueNode(j)})); + inputs2.push_back(func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimGetAttr), item.first, NewValueNode(j)})); j++; } - inputs.push_back(func_graph->NewCNode(inputs2)); + inputs.push_back(func_graph->NewCNodeInOrder(inputs2)); } - return func_graph->NewCNode(inputs); + return func_graph->NewCNodeInOrder(inputs); } AnfNodePtr HyperMap::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) { @@ -318,8 +318,8 @@ ArgsPairList HyperMap::Harmonize(const FuncGraphPtr &func_graph, const ArgsPairL for (auto &item : args_spec_list) { if 
(!IsSubType(item.second, type_tensor)) { TypePtr type_tensor_ele = std::make_shared(item.second); - ret.push_back( - std::make_pair(func_graph->NewCNode({NewValueNode(prim::kPrimScalarToArray), item.first}), type_tensor_ele)); + ret.push_back(std::make_pair(func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimScalarToArray), item.first}), + type_tensor_ele)); } else { ret.push_back(std::make_pair(item.first, item.second)); } @@ -414,14 +414,14 @@ FuncGraphPtr Tail::GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr & if (tail_type_ == kGradAll) { MS_EXCEPTION_IF_NULL((*sequeue)[i]); if ((*sequeue)[i]->isa()) { - elems.push_back(ret->NewCNode({NewValueNode(op), ptrTup, NewValueNode(SizeToLong(i))})); + elems.push_back(ret->NewCNodeInOrder({NewValueNode(op), ptrTup, NewValueNode(SizeToLong(i))})); } } else { - elems.push_back(ret->NewCNode({NewValueNode(op), ptrTup, NewValueNode(SizeToLong(i))})); + elems.push_back(ret->NewCNodeInOrder({NewValueNode(op), ptrTup, NewValueNode(SizeToLong(i))})); } } - ret->set_output(ret->NewCNode(elems)); + ret->set_output(ret->NewCNodeInOrder(elems)); return ret; } @@ -458,7 +458,7 @@ FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList &arg } // make fprob first result, maketuple's forward result. - AnfNodePtr out = fg->NewCNode(params); + AnfNodePtr out = fg->NewCNodeInOrder(params); // make fprob second result, maketuple's backward function. FuncGraphPtr b = std::make_shared(); @@ -472,14 +472,14 @@ FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList &arg grads.push_back(NewValueNode(prim::kPrimMakeTuple)); grads.push_back(NewValueNode(newenv)); for (int64_t i = 0; i < tuple_size; ++i) { - grads.push_back(b->NewCNode({NewValueNode(prim::kPrimTupleGetItem), dout, NewValueNode(i)})); + grads.push_back(b->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), dout, NewValueNode(i)})); } b->set_flag(FUNC_GRAPH_FLAG_CORE, true); - b->set_output(b->NewCNode(grads)); + b->set_output(b->NewCNodeInOrder(grads)); fg->set_flag(FUNC_GRAPH_FLAG_CORE, true); - fg->set_output(fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), out, NewValueNode(b)})); + fg->set_output(fg->NewCNodeInOrder({NewValueNode(prim::kPrimMakeTuple), out, NewValueNode(b)})); (void)fg->transforms().emplace("primal", FuncGraphTransform(prim::kPrimMakeTuple)); return fg; } @@ -499,7 +499,7 @@ FuncGraphPtr MakeListGradient::GenerateFuncGraph(const AbstractBasePtrList &args } // make fprob first result, maketuple's forward result. - AnfNodePtr out = fg->NewCNode(params); + AnfNodePtr out = fg->NewCNodeInOrder(params); // make fprob second result, maketuple's backward function. 
FuncGraphPtr b = std::make_shared(); @@ -513,14 +513,14 @@ FuncGraphPtr MakeListGradient::GenerateFuncGraph(const AbstractBasePtrList &args grads.push_back(NewValueNode(prim::kPrimMakeTuple)); grads.push_back(NewValueNode(newenv)); for (int64_t i = 0; i < list_size; ++i) { - grads.push_back(b->NewCNode({NewValueNode(prim::kPrimListGetItem), dout, NewValueNode(i)})); + grads.push_back(b->NewCNodeInOrder({NewValueNode(prim::kPrimListGetItem), dout, NewValueNode(i)})); } b->set_flag(FUNC_GRAPH_FLAG_CORE, true); - b->set_output(b->NewCNode(grads)); + b->set_output(b->NewCNodeInOrder(grads)); fg->set_flag(FUNC_GRAPH_FLAG_CORE, true); - fg->set_output(fg->NewCNode({NewValueNode(prim::kPrimMakeTuple), out, NewValueNode(b)})); + fg->set_output(fg->NewCNodeInOrder({NewValueNode(prim::kPrimMakeTuple), out, NewValueNode(b)})); (void)fg->transforms().emplace("primal", FuncGraphTransform(prim::kPrimMakeList)); return fg; } @@ -545,7 +545,7 @@ FuncGraphPtr GradOperation::GetGrad(const AnfNodePtr &k, const AnfNodePtr &weigh if (weights != nullptr) { weights_node = weights; } else if (!weight_args.empty()) { - weights_node = k_child->NewCNode(weight_args); + weights_node = k_child->NewCNodeInOrder(weight_args); } std::vector inputs; @@ -553,11 +553,11 @@ FuncGraphPtr GradOperation::GetGrad(const AnfNodePtr &k, const AnfNodePtr &weigh for (size_t i = 0; i < forward_graph_params.size(); ++i) { inputs.push_back(k_child->add_parameter()); } - auto k_app = k_child->NewCNode(inputs); + auto k_app = k_child->NewCNodeInOrder(inputs); auto tuple_get_item = NewValueNode(prim::kPrimTupleGetItem); - auto f_app = k_child->NewCNode({tuple_get_item, k_app, NewValueNode(static_cast(0))}); - auto bprop = k_child->NewCNode({tuple_get_item, k_app, NewValueNode(static_cast(1))}); + auto f_app = k_child->NewCNodeInOrder({tuple_get_item, k_app, NewValueNode(static_cast(0))}); + auto bprop = k_child->NewCNodeInOrder({tuple_get_item, k_app, NewValueNode(static_cast(1))}); GradByParameter(k_child, f_app, bprop, weights_node); return k_child; @@ -573,31 +573,31 @@ void GradOperation::GradByParameter(const FuncGraphPtr &k_child, const AnfNodePt bprop_arg = k_child->add_parameter(); } else { auto ones_like = prim::GetPythonOps("ones_like"); - bprop_arg = k_child->NewCNode({NewValueNode(ones_like), f_app}); + bprop_arg = k_child->NewCNodeInOrder({NewValueNode(ones_like), f_app}); } - AnfNodePtr b_app = k_child->NewCNode({bprop, bprop_arg}); + AnfNodePtr b_app = k_child->NewCNodeInOrder({bprop, bprop_arg}); CNodePtr fv_bprop = nullptr; if (get_by_list_) { // python code: grads = hyper_map(F.partial(env_get, env), weights) AnfNodePtr env = - k_child->NewCNode({NewValueNode(prim::kPrimTupleGetItem), b_app, NewValueNode(static_cast(0))}); + k_child->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), b_app, NewValueNode(static_cast(0))}); AnfNodePtr partial_env_get = - k_child->NewCNode({NewValueNode(prim::kPrimPartial), NewValueNode(prim::GetPythonOps("env_get")), env}); + k_child->NewCNodeInOrder({NewValueNode(prim::kPrimPartial), NewValueNode(prim::GetPythonOps("env_get")), env}); MetaFuncGraphPtr hyper_map = std::make_shared(); - fv_bprop = k_child->NewCNode({NewValueNode(hyper_map), partial_env_get, weights}); + fv_bprop = k_child->NewCNodeInOrder({NewValueNode(hyper_map), partial_env_get, weights}); } CNodePtr inputs_bprop = nullptr; if (get_all_) { TailPtr tail_grad_all = std::make_shared("tail_grad_all", kGradAll); - inputs_bprop = k_child->NewCNode({NewValueNode(tail_grad_all), b_app}); + inputs_bprop = 
k_child->NewCNodeInOrder({NewValueNode(tail_grad_all), b_app}); } // Gradients wrt inputs and parameters if (fv_bprop != nullptr && inputs_bprop != nullptr) { - k_child->set_output(k_child->NewCNode({NewValueNode(kPrimMakeTuple), inputs_bprop, fv_bprop})); + k_child->set_output(k_child->NewCNodeInOrder({NewValueNode(kPrimMakeTuple), inputs_bprop, fv_bprop})); return; } @@ -616,7 +616,7 @@ void GradOperation::GradByParameter(const FuncGraphPtr &k_child, const AnfNodePt // b_app returns (EnvInstance(grads wrt params), grads wrt input0, grads wrt input1, ...), // so obtain first input grad by setting tail_type of Tail to kGradFirst. TailPtr tail_grad_first = std::make_shared("tail_grad_first", kGradFirst); - k_child->set_output(k_child->NewCNode({NewValueNode(tail_grad_first), b_app})); + k_child->set_output(k_child->NewCNodeInOrder({NewValueNode(tail_grad_first), b_app})); } // Generate the graph. @@ -659,7 +659,7 @@ FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_sp std::vector inputs; inputs.push_back(NewValueNode(prim::kPrimJ)); inputs.push_back(param_graph); - auto j = grad_fg->NewCNode(inputs); + auto j = grad_fg->NewCNodeInOrder(inputs); // df is checked in GetGrad FuncGraphPtr k_child = nullptr; { @@ -706,26 +706,27 @@ FuncGraphPtr ListMap::GenerateFuncGraph(const AbstractBasePtrList &args_spec_lis std::vector iters; (void)std::transform(lists.begin(), lists.end(), std::back_inserter(iters), [fg_ptr](AnfNodePtr item) { - return fg_ptr->NewCNode({NewValueNode(std::string("list_iter")), item}); + return fg_ptr->NewCNodeInOrder({NewValueNode(std::string("list_iter")), item}); }); std::vector nexts; (void)std::transform(iters.begin(), iters.end(), std::back_inserter(nexts), [fg_ptr](AnfNodePtr item) { - return fg_ptr->NewCNode({NewValueNode(std::string("next")), item}); + return fg_ptr->NewCNodeInOrder({NewValueNode(std::string("next")), item}); }); std::vector values; (void)std::transform(nexts.begin(), nexts.end(), std::back_inserter(values), [fg_ptr](AnfNodePtr item) { - return fg_ptr->NewCNode({NewValueNode(prim::kPrimTupleGetItem), item}); + return fg_ptr->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), item}); }); (void)std::transform(nexts.begin(), nexts.end(), std::back_inserter(iters), [fg_ptr](AnfNodePtr item) { - return fg_ptr->NewCNode({NewValueNode(prim::kPrimTupleGetItem), item, NewValueNode(static_cast(1))}); + return fg_ptr->NewCNodeInOrder( + {NewValueNode(prim::kPrimTupleGetItem), item, NewValueNode(static_cast(1))}); }); (void)values.insert(values.begin(), fn); - AnfNodePtr cnode_graph = fg_ptr->NewCNode(values); - AnfNodePtr resl = fg_ptr->NewCNode({NewValueNode(prim::kPrimMakeList), cnode_graph}); + AnfNodePtr cnode_graph = fg_ptr->NewCNodeInOrder(values); + AnfNodePtr resl = fg_ptr->NewCNodeInOrder({NewValueNode(prim::kPrimMakeList), cnode_graph}); FuncGraphPtr fgnext_ptr = std::make_shared(); fgnext_ptr->debug_info()->set_name("body"); @@ -736,7 +737,7 @@ FuncGraphPtr ListMap::GenerateFuncGraph(const AbstractBasePtrList &args_spec_lis MakeCond(lists, fgnext_ptr, fgcond_ptr); MakeNext(lists, fgcond_ptr, fgnext_ptr); - CNodePtr output_cnode = fg_ptr->NewCNode({NewValueNode(fgcond_ptr), fn, resl}); + CNodePtr output_cnode = fg_ptr->NewCNodeInOrder({NewValueNode(fgcond_ptr), fn, resl}); auto inputs = output_cnode->inputs(); (void)inputs.insert(inputs.end(), iters.begin(), iters.end()); @@ -759,7 +760,7 @@ void ListMap::MakeCond(const std::vector &lists, const FuncGraphPtr std::vector hasnexts; (void)std::transform(iters.begin(), 
iters.end(), std::back_inserter(hasnexts), [fg_ptr](AnfNodePtr item) { - return fg_ptr->NewCNode({NewValueNode(std::string("hasnext")), item}); + return fg_ptr->NewCNodeInOrder({NewValueNode(std::string("hasnext")), item}); }); // cond = reduce(lambda a, b: g.apply(P.bool_and, a, b), hasnexts) @@ -767,7 +768,7 @@ void ListMap::MakeCond(const std::vector &lists, const FuncGraphPtr fgtrue_ptr->debug_info()->set_name("ftrue"); fgtrue_ptr->set_flag(FUNC_GRAPH_FLAG_CORE, true); - CNodePtr fgtrue_output_cnode = fgtrue_ptr->NewCNode({NewValueNode(fgnext_ptr), fn, resl}); + CNodePtr fgtrue_output_cnode = fgtrue_ptr->NewCNodeInOrder({NewValueNode(fgnext_ptr), fn, resl}); auto inputs = fgtrue_output_cnode->inputs(); (void)inputs.insert(inputs.end(), iters.begin(), iters.end()); fgtrue_output_cnode->set_inputs(inputs); @@ -778,8 +779,8 @@ void ListMap::MakeCond(const std::vector &lists, const FuncGraphPtr fgfalse_ptr->set_flag(FUNC_GRAPH_FLAG_CORE, true); fgfalse_ptr->set_output(resl); - AnfNodePtr output_cnode = fg_ptr->NewCNode({NewValueNode(prim::kPrimSwitch), NewValueNode(std::string("cond")), - NewValueNode(fgtrue_ptr), NewValueNode(fgfalse_ptr)}); + AnfNodePtr output_cnode = fg_ptr->NewCNodeInOrder({NewValueNode(prim::kPrimSwitch), NewValueNode(std::string("cond")), + NewValueNode(fgtrue_ptr), NewValueNode(fgfalse_ptr)}); fgtrue_ptr->set_output(output_cnode); } @@ -794,23 +795,24 @@ void ListMap::MakeNext(const std::vector &lists, const FuncGraphPtr std::vector nexts; (void)std::transform(iters.begin(), iters.end(), std::back_inserter(nexts), [fg_ptr](AnfNodePtr item) { - return fg_ptr->NewCNode({NewValueNode(std::string("next")), item}); + return fg_ptr->NewCNodeInOrder({NewValueNode(std::string("next")), item}); }); std::vector values; (void)std::transform(nexts.begin(), nexts.end(), std::back_inserter(values), [fg_ptr](AnfNodePtr item) { - return fg_ptr->NewCNode({NewValueNode(prim::kPrimTupleGetItem), item, nullptr}); + return fg_ptr->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), item, nullptr}); }); iters.clear(); (void)std::transform(nexts.begin(), nexts.end(), std::back_inserter(iters), [fg_ptr](AnfNodePtr item) { - return fg_ptr->NewCNode({NewValueNode(prim::kPrimTupleGetItem), item, NewValueNode(static_cast(1))}); + return fg_ptr->NewCNodeInOrder( + {NewValueNode(prim::kPrimTupleGetItem), item, NewValueNode(static_cast(1))}); }); (void)values.insert(values.begin(), fn); - AnfNodePtr cnode_graph = fg_ptr->NewCNode(values); - AnfNodePtr resl = fg_ptr->NewCNode({NewValueNode(prim::kPrimListAppend), cnode_graph}); - CNodePtr output_cnode = fg_ptr->NewCNode({NewValueNode(fgcond_ptr), fn, resl}); + AnfNodePtr cnode_graph = fg_ptr->NewCNodeInOrder(values); + AnfNodePtr resl = fg_ptr->NewCNodeInOrder({NewValueNode(prim::kPrimListAppend), cnode_graph}); + CNodePtr output_cnode = fg_ptr->NewCNodeInOrder({NewValueNode(fgcond_ptr), fn, resl}); auto inputs = output_cnode->inputs(); (void)inputs.insert(inputs.end(), iters.begin(), iters.end()); @@ -853,15 +855,15 @@ FuncGraphPtr TupleAdd::GenerateFuncGraph(const AbstractBasePtrList &args_spec_li int64_t tuple_size = SizeToLong(a_tuple->size()); for (int64_t i = 0; i < tuple_size; ++i) { - elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimTupleGetItem), p_tup_a, NewValueNode(i)})); + elems.push_back(ret->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), p_tup_a, NewValueNode(i)})); } tuple_size = SizeToLong(b_tuple->size()); for (int64_t i = 0; i < tuple_size; ++i) { - 
elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimTupleGetItem), p_tup_b, NewValueNode(i)})); + elems.push_back(ret->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), p_tup_b, NewValueNode(i)})); } - ret->set_output(ret->NewCNode(elems)); + ret->set_output(ret->NewCNodeInOrder(elems)); return ret; } @@ -956,15 +958,15 @@ FuncGraphPtr TupleSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_ elems.push_back(NewValueNode(prim::kPrimMakeTuple)); if (step_value > 0) { for (int64_t index = start_index; index < stop_index; index = index + step_value) { - elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimTupleGetItem), p_tuple, NewValueNode(index)})); + elems.push_back(ret->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), p_tuple, NewValueNode(index)})); } } else { for (int64_t index = start_index; index > stop_index; index = index + step_value) { - elems.push_back(ret->NewCNode({NewValueNode(prim::kPrimTupleGetItem), p_tuple, NewValueNode(index)})); + elems.push_back(ret->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), p_tuple, NewValueNode(index)})); } } - ret->set_output(ret->NewCNode(elems)); + ret->set_output(ret->NewCNodeInOrder(elems)); return ret; } @@ -978,7 +980,7 @@ FuncGraphPtr TupleGetItemTensor::GenerateFuncGraph(const AbstractBasePtrList &ar auto functions = ret_graph->add_parameter(); auto index = ret_graph->add_parameter(); - ret_graph->set_output(ret_graph->NewCNode({NewValueNode(prim::kPrimSwitchLayer), index, functions})); + ret_graph->set_output(ret_graph->NewCNodeInOrder({NewValueNode(prim::kPrimSwitchLayer), index, functions})); return ret_graph; } diff --git a/mindspore/ccsrc/frontend/operator/composite/do_signature.cc b/mindspore/ccsrc/frontend/operator/composite/do_signature.cc index d1326ac943..e5eeb75c7b 100644 --- a/mindspore/ccsrc/frontend/operator/composite/do_signature.cc +++ b/mindspore/ccsrc/frontend/operator/composite/do_signature.cc @@ -194,7 +194,7 @@ AnfNodePtr DoCast(const AnfNodePtr ¶m, const TypeId &type_id, const FuncGrap MS_EXCEPTION_IF_NULL(prim_cast_class); auto dtype_node = NewValueNode(TypeIdToType(type_id)); auto cast_node = NewCNode({NewValueNode(prim_cast_class)}, graph); - return NewCNode({cast_node, param, dtype_node}, graph); + return graph->NewCNodeAfter(param, {cast_node, param, dtype_node}); } void DoAutoCast(const std::string &func_name, const std::vector &signature, @@ -274,7 +274,7 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func continue; } SignatureEnumRW sig = SignatureEnumRW::kRWDefault; - // If sig_size is 0 use defalut. + // If sig_size is 0 use default. if (sig_size > 0 && i < sig_size) { sig = signature[i].rw; } else if (has_var && i >= sig_size) { @@ -289,7 +289,7 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func auto source_element = source_tensor_type->element(); if (cast_type != nullptr && IsSubType(source_element, kFloat) && *source_element != *cast_type) { auto cast = prim::GetPythonOps("cast", "mindspore.ops.functional"); - param = NewCNode({NewValueNode(cast), param, NewValueNode(cast_type)}, func_graph); + param = func_graph->NewCNodeAfter(param, {NewValueNode(cast), param, NewValueNode(cast_type)}); type = cast_type->type_id() == kNumberTypeFloat16 ? 
kTensorTypeFP16 : kTensorTypeFP32; } } @@ -309,7 +309,7 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func // process default ProcessDefault(func_name, args_spec_list.size(), signature, has_var, &op_inputs); DoAutoCast(func_name, signature, input_types, func_graph, &op_inputs, write_indices); - return func_graph->NewCNode(op_inputs); + return func_graph->NewCNodeInOrder(op_inputs); } } // namespace diff --git a/mindspore/ccsrc/frontend/operator/composite/map.cc b/mindspore/ccsrc/frontend/operator/composite/map.cc index 4a92634537..a2bf8b2cd2 100644 --- a/mindspore/ccsrc/frontend/operator/composite/map.cc +++ b/mindspore/ccsrc/frontend/operator/composite/map.cc @@ -44,7 +44,7 @@ AnfNodePtr Map::FullMakeLeaf(const FuncGraphPtr &func_graph, const AnfNodePtr &f inputs.emplace_back(NewValueNode(fn_leaf_)); } inputs.insert(inputs.end(), args.begin(), args.end()); - return func_graph->NewCNode(inputs); + return func_graph->NewCNodeInOrder(inputs); } FuncGraphPtr Map::GenerateLeafFunc(const size_t &args_size) { @@ -98,12 +98,12 @@ AnfNodePtr Map::FullMakeList(const std::shared_ptr &type, const FuncGraphP (void)std::transform( arg_pairs.begin(), arg_pairs.end(), std::back_inserter(inputs2), [&func_graph, i](const std::pair &item) { - return func_graph->NewCNode({NewValueNode(prim::kPrimListGetItem), item.first, NewValueNode(i)}); + return func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimListGetItem), item.first, NewValueNode(i)}); }); - inputs.push_back(func_graph->NewCNode(inputs2)); + inputs.push_back(func_graph->NewCNodeInOrder(inputs2)); } - return func_graph->NewCNode(inputs); + return func_graph->NewCNodeInOrder(inputs); } AnfNodePtr Map::FullMakeTuple(const std::shared_ptr &type, const FuncGraphPtr &func_graph, @@ -139,12 +139,12 @@ AnfNodePtr Map::FullMakeTuple(const std::shared_ptr &type, const FuncGrap (void)std::transform( arg_pairs.begin(), arg_pairs.end(), std::back_inserter(inputs2), [&func_graph, &i](std::pair item) { - return func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), item.first, NewValueNode(i)}); + return func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), item.first, NewValueNode(i)}); }); - inputs.push_back(func_graph->NewCNode(inputs2)); + inputs.push_back(func_graph->NewCNodeInOrder(inputs2)); } - return func_graph->NewCNode(inputs); + return func_graph->NewCNodeInOrder(inputs); } AnfNodePtr Map::FullMakeClass(const std::shared_ptr &type, const FuncGraphPtr &func_graph, @@ -170,13 +170,13 @@ AnfNodePtr Map::FullMakeClass(const std::shared_ptr &type, const FuncGrap int64_t j = 0; for (auto item : arg_pairs) { - inputs2.push_back(func_graph->NewCNode({NewValueNode(prim::kPrimGetAttr), item.first, NewValueNode(j)})); + inputs2.push_back(func_graph->NewCNodeInOrder({NewValueNode(prim::kPrimGetAttr), item.first, NewValueNode(j)})); j++; } - inputs.push_back(func_graph->NewCNode(inputs2)); + inputs.push_back(func_graph->NewCNodeInOrder(inputs2)); } - return func_graph->NewCNode(inputs); + return func_graph->NewCNodeInOrder(inputs); } AnfNodePtr Map::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs) { diff --git a/mindspore/ccsrc/frontend/operator/composite/multitype_funcgraph.cc b/mindspore/ccsrc/frontend/operator/composite/multitype_funcgraph.cc index 56815e4089..6b3b07020b 100644 --- a/mindspore/ccsrc/frontend/operator/composite/multitype_funcgraph.cc +++ b/mindspore/ccsrc/frontend/operator/composite/multitype_funcgraph.cc @@ -79,14 +79,35 @@ void 
MultitypeFuncGraph::PyRegister(const py::tuple &tuple, const py::function &
   Register(types, py_fn);
 }
+namespace {
+bool HasUMonadType(const TypePtrList &types) {
+  auto types_size = types.size();
+  // If UMonad is the only type, ignore it.
+  if (types_size > 1) {
+    auto last_type = types[types_size - 1];
+    if (IsIdentidityOrSubclass(last_type, kUMonadType)) {
+      MS_LOG(DEBUG) << "Have extra UMonad type";
+      return true;
+    }
+  }
+  return false;
+}
+}  // namespace
+
 // Return the exact match if it exists, else return a non-ambiguous subclass match.
 // Return py::none() if matching is ambiguous.
-const py::function MultitypeFuncGraph::SignMatch(const TypePtrList &types) {
+const std::pair MultitypeFuncGraph::SignMatch(const TypePtrList &types) {
   // Exact match
   for (auto &item : fn_cache_py_) {
+    bool has_extra_u_monad = false;
     TypePtrList sign = item.first;
-    if (sign.size() != types.size()) {
-      continue;
+    auto types_size = types.size();
+    if (sign.size() != types_size) {
+      // Don't take the UMonad type into account.
+      has_extra_u_monad = (types_size > 1) && (sign.size() == (types_size - 1)) && HasUMonadType(types);
+      if (!has_extra_u_monad) {
+        continue;
+      }
     }
     auto match = true;
     for (size_t i = 0; i < sign.size(); ++i) {
@@ -98,13 +119,14 @@ const py::function MultitypeFuncGraph::SignMatch(const TypePtrList &types) {
     if (!match) {
       continue;
     }
-    return item.second;
+    return std::pair(item.second, has_extra_u_monad);
   }
-  return py::none();
+  return std::pair(py::none(), false);
 }

 FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) {
-  auto py_fn = SignMatch(types);
+  auto py_fn_pair = SignMatch(types);
+  auto py_fn = py_fn_pair.first;
   std::ostringstream buffer;
   buffer << types;
   if (!py_fn.is_none()) {
@@ -113,6 +135,10 @@ FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) {
       MS_LOG(EXCEPTION) << "Fail to parse overload function " << buffer.str();
     }
     MS_LOG(DEBUG) << "Find overload function " << buffer.str() << ", function: " << func_graph->ToString();
+    if (py_fn_pair.second) {
+      MS_LOG(DEBUG) << "Add extra UMonad type for func_graph: " << func_graph->ToString();
+      func_graph->add_parameter();
+    }
     return func_graph;
   }
   auto stub = GenerateStubFunc(types);
diff --git a/mindspore/ccsrc/frontend/operator/composite/multitype_funcgraph.h b/mindspore/ccsrc/frontend/operator/composite/multitype_funcgraph.h
index 15d8449cd7..279722e2fa 100644
--- a/mindspore/ccsrc/frontend/operator/composite/multitype_funcgraph.h
+++ b/mindspore/ccsrc/frontend/operator/composite/multitype_funcgraph.h
@@ -53,7 +53,7 @@ class MultitypeFuncGraph : public MetaFuncGraph {
   }

  private:
-  const py::function SignMatch(const TypePtrList &types);
+  const std::pair SignMatch(const TypePtrList &types);
   std::unordered_map fn_cache_;
   std::unordered_map fn_cache_py_;
 };
diff --git a/mindspore/ccsrc/frontend/operator/composite/unpack_call.cc b/mindspore/ccsrc/frontend/operator/composite/unpack_call.cc
index 2fa4e3f780..502d4ec10a 100644
--- a/mindspore/ccsrc/frontend/operator/composite/unpack_call.cc
+++ b/mindspore/ccsrc/frontend/operator/composite/unpack_call.cc
@@ -79,7 +79,8 @@ FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList &args_spec_
                         << args_spec_list[index]->ToString();
     }
   }
-  ret_graph->set_output(ret_graph->NewCNode(elems));
+  // Add to the order list to trace whether fn_node has side effects.
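+  // A note on the NewCNode -> NewCNodeInOrder changes throughout this patch:
+  // the assumed difference (a simplified sketch, not the actual FuncGraph
+  // implementation) is that NewCNodeInOrder also records the node in the
+  // graph's order list, so auto-monad can keep side-effecting calls in
+  // source order:
+  //   CNodePtr FuncGraph::NewCNodeInOrder(std::vector<AnfNodePtr> &&inputs) {
+  //     auto cnode = NewCNode(std::move(inputs));
+  //     order_list_.push_back(cnode);  // hypothetical member: creation order
+  //     return cnode;
+  //   }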
+ ret_graph->set_output(ret_graph->NewCNodeInOrder(elems)); return ret_graph; } diff --git a/mindspore/ccsrc/frontend/operator/ops_front_infer_function.cc b/mindspore/ccsrc/frontend/operator/ops_front_infer_function.cc index 58fd385e3c..b0c275354e 100644 --- a/mindspore/ccsrc/frontend/operator/ops_front_infer_function.cc +++ b/mindspore/ccsrc/frontend/operator/ops_front_infer_function.cc @@ -516,7 +516,7 @@ AbstractBasePtr InferImplMakeRange(const AnalysisEnginePtr &, const PrimitivePtr AbstractBasePtr InferImplStopGradient(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) { - // Inputs: a tensor + // Inputs: any value; CheckArgsSize(primitive->name(), args_spec_list, 1); return args_spec_list[0]->Clone(); } @@ -642,8 +642,8 @@ AbstractBasePtr InferImplMakeRecord(const AnalysisEnginePtr &, const PrimitivePt AbstractBasePtr InferImplAssign(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) { - // Inputs: a tensor - CheckArgsSize(primitive->name(), args_spec_list, 2); + // Inputs: Ref, value, [universal] + CheckRequiredArgsSize(primitive->name(), args_spec_list, 2); MS_LOG(DEBUG) << "InferImplAssign " << args_spec_list[0]; auto type = args_spec_list[0]->BuildType(); @@ -654,6 +654,18 @@ AbstractBasePtr InferImplAssign(const AnalysisEnginePtr &, const PrimitivePtr &p } } +AbstractBasePtr InferImplLoad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: Ref/Tensor, universal + CheckArgsSize(primitive->name(), args_spec_list, 2); + auto ref_abs = dyn_cast(args_spec_list[0]); + if (ref_abs != nullptr) { + // Return tensor value if input is Ref. + return ref_abs->CloneAsTensor(); + } + return args_spec_list[0]->Broaden(); +} + REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(TypeOf, prim::kPrimTypeOf, InferImplTypeof); REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(HasType, prim::kPrimHasType, InferImplHasType); REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(MakeRecord, prim::kPrimMakeRecord, InferImplMakeRecord); @@ -676,5 +688,6 @@ REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(J, prim::kPrimJ, InferImplJ); REGISTER_FRONTENT_PRIMITIVE_EVAL_IMPL(BroadcastGradientArgs, prim::kPrimBroadcastGradientArgs, InferImplBroadcastGradientArgs); REGISTER_PRIMITIVE_EVAL_IMPL(Assign, prim::kPrimAssign, InferImplAssign); +REGISTER_PRIMITIVE_EVAL_IMPL(Load, prim::kPrimLoad, InferImplLoad); } // namespace abstract } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/optimizer/ad/adjoint.cc b/mindspore/ccsrc/frontend/optimizer/ad/adjoint.cc index 559836e4fb..892b599416 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/adjoint.cc +++ b/mindspore/ccsrc/frontend/optimizer/ad/adjoint.cc @@ -36,7 +36,7 @@ Adjoint::Adjoint(const AnfNodePtr &primal, const AnfNodePtr &k, const FuncGraphP MS_LOG(DEBUG) << "Add hole for " << primal->ToString() << " " << k_->ToString(); } - dout_hole_ = caller_->NewCNode({NewValueNode(prim::GetPythonOps("zeros_like")), k_}); + dout_hole_ = caller_->NewCNodeInFront({NewValueNode(prim::GetPythonOps("zeros_like")), k_}); RegisterKUser(dout_hole_->cast(), 1); } diff --git a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc index 811a53e4e9..ca825519f9 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc +++ b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc @@ -168,11 +168,28 @@ void DFunctor::BackPropagateSwitchLayer(const CNodePtr &cnode_morph, const CNode } } +static bool 
HasSideEffectBackProp(const CNodePtr &cnode) {
+  if (IsPrimitiveCNode(cnode)) {
+    const auto &prim = GetCNodePrimitive(cnode);
+    MS_EXCEPTION_IF_NULL(prim);
+    auto bprop_flag = GetPrimitiveFlag(prim, GRAPH_FLAG_SIDE_EFFECT_BACKPROP);
+    return bprop_flag;
+  }
+  return false;
+}
+
 void DFunctor::BackPropagate(const CNodePtr &cnode_morph, const CNodePtr &k_app, const AdjointPtr &node_adjoint) {
   auto bprop = k_graph_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), k_app, NewValueNode(static_cast(1))});
   // Call with delimited continuation dout.
-  auto bprop_app = tape_->NewCNode({bprop, node_adjoint->dout()});
+  CNodePtr bprop_app;
+  if (HasSideEffectBackProp(cnode_morph)) {
+    // As MapMorphism is called recursively, the bprop_app nodes are prepended so
+    // that their order is the reverse of the visit order.
+    bprop_app = tape_->NewCNodeInFront({bprop, node_adjoint->dout()});
+    tape_->set_flag(mindspore::kFuncGraphFlagReAutoMonad, true);
+  } else {
+    bprop_app = tape_->NewCNode({bprop, node_adjoint->dout()});
+  }
   node_adjoint->RegisterDoutUser(bprop_app, 1);
   // Special case for switch_layer
   if (IsPrimitiveCNode(cnode_morph, prim::kPrimSwitchLayer)) {
@@ -358,10 +375,10 @@ void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_mor
   if (inputs_value.empty()) {
     return;
   }
-  if (inputs_value.size() != paras.size()) {
-    MS_LOG(EXCEPTION) << "Parameter size:" << paras.size() << " is not equal to inputs size:" << inputs_value.size();
+  if (inputs_value.size() > paras.size()) {
+    MS_LOG(EXCEPTION) << "Parameter size:" << paras.size() << " but inputs size:" << inputs_value.size();
   }
-  for (size_t i = 0; i < paras.size(); i++) {
+  for (size_t i = 0; i < inputs_value.size(); i++) {
     auto para_ref_size = manager->node_users()[paras[i]].size();
     auto input_value = inputs_value[i];
     if (para_ref_size > 0 && input_value.first != nullptr) {
@@ -415,7 +432,7 @@ bool DFunctor::IsFreeMorphism(const AnfNodePtr &node) {
 }

 void DFunctor::MapFreeMorphism() {
-  // Handle cnode not attached to output, that might be refered in other functions.
+  // Handle cnode not attached to output, that might be referred in other functions.
for (auto &node : primal_graph_->nodes()) { if (!IsFreeMorphism(node)) { continue; @@ -522,7 +539,10 @@ FuncGraphPtr DFunctor::KUserDefined(const FuncGraphPtr &primal) { MS_LOG(EXCEPTION) << "User defined Cell bprop " << primal->ToString() << " in scope " << primal->output()->scope()->name() << " does not support Parameter data type."; } - auto fg = g_k_prims.KUserDefinedCellBprop(bprop_graph); + bprop_graph->set_flag(mindspore::kFuncGraphFlagBackPropEntry, true); + bprop_graph->set_flag(mindspore::kFuncGraphFlagReAutoMonad, true); + + auto fg = g_k_prims.KUserDefinedCellBprop(bprop_graph, primal); if (fg == nullptr) { MS_LOG(EXCEPTION) << "Failed to expand user defined Cell bprop " << primal->ToString() << " in scope " << primal->output()->scope()->name() << "."; @@ -553,8 +573,9 @@ AnfNodePtr DFunctor::MapPrimitiveToK(const CNodePtr &primitive_user, size_t inde // Map Primitive to K auto value_node = primal->cast(); auto prim = GetValueNode(value_node); - if (prim->Hash() == prim::kPrimStopGradient->Hash() && prim->name() == prim::kPrimStopGradient->name()) { - MS_LOG(DEBUG) << "Meet a kPrimStopGradient " << prim->ToString() << "."; + if ((prim->Hash() == prim::kPrimStopGradient->Hash() && prim->name() == prim::kPrimStopGradient->name()) || + (prim->Hash() == prim::kPrimUpdateState->Hash() && prim->name() == prim::kPrimUpdateState->name())) { + MS_LOG(DEBUG) << "Should stop gradient for " << prim->ToString(); need_cut_ = true; } auto k_prim = g_k_prims.KPrimitive(primitive_user, value_node, resources_); @@ -748,8 +769,11 @@ void DFunctor::CallDoutHoleOnTape() { } } } + FuncGraphPtr DFunctor::k_graph() { return k_graph_; } +FuncGraphPtr DFunctor::tape() { return tape_; } + void DFunctor::BroadCastStopFlag() { // As stop set expanding, all directly or indirectly stopped CNode will be cut off while (need_cut_) { @@ -759,7 +783,8 @@ void DFunctor::BroadCastStopFlag() { auto cnode = node->cast(); if (!cnode->stop_gradient()) { // Cut off the cnode only when it's not referred any more - if (IsPrimitiveCNode(cnode, prim::kPrimStopGradient) || AllReferencesStopped(cnode)) { + if (IsPrimitiveCNode(cnode, prim::kPrimStopGradient) || IsPrimitiveCNode(cnode, prim::kPrimUpdateState) || + AllReferencesStopped(cnode)) { MS_LOG(DEBUG) << "Set stop gradient flag for " << cnode->ToString() << "."; cnode->set_stop_gradient(true); // The stop set changed, more cut required @@ -786,28 +811,133 @@ bool DFunctor::AllReferencesStopped(const CNodePtr &node) { return true; } -// To replace the primal graph with k graph +static std::pair FindPrimalJPair(const FuncGraphManagerPtr &manager, + const FuncGraphPtr &primal_graph) { + CNodePtr primal_user = nullptr; + CNodePtr j_user = nullptr; + auto &node_user_map = manager->node_users(); + // Search primal graph user cnodes. + for (auto &entry : primal_graph->func_graph_cnodes_index()) { + auto cnode = entry.first->first->cast(); + auto index = entry.first->second; + if (index == 0) { + // To find real calling. + primal_user = cnode; + } else if (IsPrimitive(cnode->inputs().at(0), prim::kPrimJ)) { + // To find J user. 
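+      // Illustrative IR for the pair being matched (names are made up):
+      //   %x = @primal(args..., %u)  // primal_user: @primal called at index 0
+      //   %j = J(@primal)            // this cnode: @primal used at index != 0
+      //   %t = %j(args..., %u)       // the unique user of %j, found below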
+      auto it = node_user_map.find(cnode);
+      if (it == node_user_map.end()) {
+        MS_LOG(EXCEPTION) << "J CNode not used {" << cnode->DebugString(2) << "/" << index << "}";
+      }
+      auto &j_users = it->second;
+      auto size = j_users.size();
+      if (size != 1) {
+        MS_LOG(EXCEPTION) << "Wrong J CNode use size " << size << " {" << cnode->DebugString(2) << "/" << index << "}";
+      }
+      j_user = j_users.begin()->first->cast();
+    }
+    if (j_user != nullptr && primal_user != nullptr) {
+      break;
+    }
+  }
+  return {primal_user, j_user};
+}
+
+static void RemovePrimalUpdateStates(const FuncGraphManagerPtr &manager, const CNodePtr &primal_call) {
+  auto &node_users = manager->node_users();
+  auto iter = node_users.find(primal_call);
+  if (iter == node_users.end()) {
+    // Skip if no user of primal_call is found.
+    return;
+  }
+  // Find UpdateState nodes after the primal call.
+  std::vector update_states;
+  for (auto &user : iter->second) {
+    auto &user_node = user.first;
+    if (IsPrimitiveCNode(user_node, prim::kPrimUpdateState)) {
+      update_states.emplace_back(user_node->cast());
+    }
+  }
+  // Remove the UpdateStates by replacing them with their monad input.
+  for (auto &update_state : update_states) {
+    auto &input_monad = update_state->inputs().at(1);
+    manager->Replace(update_state, input_monad);
+  }
+}
+
+static bool CopyMonadArguments(const CNodePtr &primal_user, const CNodePtr &j_user) {
+  auto &primal_inputs = primal_user->inputs();
+  auto &j_user_inputs = j_user->inputs();
+  bool has_monad = false;
+  for (size_t i = 1; i < primal_inputs.size(); ++i) {
+    auto &input = primal_inputs.at(i);
+    if (HasAbstractMonad(input)) {
+      // Copy monad input from primal to j_user.
+      j_user->set_input(i, input);
+      has_monad = true;
+    } else if (input != j_user_inputs.at(i)) {
+      // Skip if there are different non-monad inputs.
+      return false;
+    }
+  }
+  return has_monad;
+}
+
+//
+// To replace the primal graph with k graph.
+// Convert:
+//   x = primal(args, u0)
+//   u1 = update_state(u0, x)
+//   ...
+//   tuple = K(args, u1)
+//   u2 = update_state(u1, tuple)
+//   ...
+// To:
+//   tuple = K(args, u0)
+//   x = get_item(tuple, 0)
+//   ...
+//   tuple = K(args, u0)
+//   u2 = update_state(u0, tuple)
+//   ...
+//
 void DFunctor::EliminatePrimalGraph() {
+  // Find the primal user and its paired J user cnode.
+  auto manager = primal_graph_->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  auto [primal_user, j_user] = FindPrimalJPair(manager, primal_graph_);
+  if (primal_user == nullptr || j_user == nullptr) {
+    // Skip if one of them is not found.
+    return;
+  }
+  // Check input size.
+  if (primal_user->size() != j_user->size()) {
+    MS_LOG(WARNING) << "Input size incorrect, primal:" << primal_user->DebugString()
+                    << " juser:" << j_user->DebugString();
+    return;
+  }
+  // Replace the primal graph with the k graph.
   auto k_vnode = NewValueNode(k_graph_);
-  auto idx0 = NewValueNode(SizeToLong(0));
+  auto primal_abs = primal_user->abstract();
+  primal_user->set_input(0, k_vnode);
+  primal_user->set_abstract(j_user->abstract());
+
+  // If the inputs are the same except for monads, we copy the primal monad args to the k graph
+  // so that the calls can be combined in the CSE (common subexpression elimination) pass.
+  const bool has_monad = CopyMonadArguments(primal_user, j_user);
+  // Remove the UpdateState nodes after primal_user if needed.
+  if (has_monad) {
+    RemovePrimalUpdateStates(manager, primal_user);
+  }
+
+  // Insert tuple_getitem after the primal user cnode.
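+  // That is, the call site has been retargeted from primal to the K graph,
+  // which returns (forward_result, bprop); the getitem below restores the
+  // original forward value for downstream users (illustrative):
+  //   x = K(args, u0)  ==>  tuple = K(args, u0); x = tuple_getitem(tuple, 0)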
+ auto construct_wrapper = primal_user->func_graph(); + auto tuple_getitem = NewValueNode(prim::kPrimTupleGetItem); auto imm0 = std::make_shared(0); + auto idx0 = NewValueNode(SizeToLong(0)); idx0->set_abstract(std::make_shared(imm0)); - auto manager = primal_graph_->manager(); - auto users = primal_graph_->func_graph_cnodes_index(); - for (auto &it : users) { - auto cnode = it.first->first->cast(); - auto index = it.first->second; - auto vnode = cnode->inputs()[index]; - if (index != 0) { - MS_LOG(DEBUG) << "Primal is used but not called, at {" << cnode->DebugString(3) << "/" << index << "}"; - continue; - } - cnode->set_input(0, k_vnode); // Replace primal graph with k graph - auto construct_wrapper = cnode->func_graph(); - TraceGuard trace_guard(std::make_shared(cnode->debug_info())); - auto getitem0 = construct_wrapper->NewCNode({NewValueNode(prim::kPrimTupleGetItem), cnode, idx0}); - manager->Replace(cnode, getitem0); - } + auto getitem0 = construct_wrapper->NewCNode({tuple_getitem, primal_user, idx0}); + getitem0->set_abstract(primal_abs); + manager->Replace(primal_user, getitem0); } } // namespace ad } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h index 229c329f4b..47497b7bea 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h +++ b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.h @@ -60,6 +60,7 @@ class DFunctor : public std::enable_shared_from_this { // Map morphism in D category to K category. void MapMorphism(); FuncGraphPtr k_graph(); + FuncGraphPtr tape(); // Construct user defined k object. FuncGraphPtr KUserDefined(const FuncGraphPtr &primal); // Register functor objects to form a global view. @@ -138,7 +139,9 @@ class KPrim { FuncGraphPtr KPrimitive(const CNodePtr &primal_user, const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources); MetaFuncGraphPtr KMetaFuncGraph(const PrimitivePtr &prim); - FuncGraphPtr KUserDefinedCellBprop(FuncGraphPtr bprop); + // bprop_fg and primal_fg in bprop_fg's transforms are FuncGraph just after convert. + // current_primal_fg is the specialized and AutoMonaded primal_fg. + FuncGraphPtr KUserDefinedCellBprop(const FuncGraphPtr &bprop_fg, const FuncGraphPtr ¤t_primal_fg); void clear() { bprop_registry_meta_.clear(); @@ -151,11 +154,19 @@ class KPrim { FuncGraphPtr FakeBprop(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources); FuncGraphPtr BpropCut(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources); // Given a bprop rule, do the K mapping. + // current_primal_fg is only valid for user defined bprop for Cell, not for Primitive. + // Refer the comment in KUserDefinedCellBprop. 
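+  // For example (illustrative): a Cell primal converted as primal_fg(x, y)
+  // may be specialized by AutoMonad into current_primal_fg(x, y, u_monad),
+  // so the K mapping must account for the extra trailing monad parameters.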
template - FuncGraphPtr BpropToK(const T &primal, const FuncGraphPtr &bprop_g, const CNodePtr &cnode); - AnfNodePtr BuildOutput(const FuncGraphPtr &bprop_fg); - void TransformArgs(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg, const FuncGraphPtr &outer, - std::vector *const transf_args); + FuncGraphPtr BpropToK(const T &primal, const FuncGraphPtr &bprop_g, const FuncGraphPtr ¤t_primal_fg, + const CNodePtr &cnode); + AnfNodePtr BuildOutput(const FuncGraphPtr &bprop_fg, const FuncGraphPtr ¤t_primal_fg); + void TransformArgsForPrimitive(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg, + const PrimitivePtr &primitive, const FuncGraphPtr &outer, + std::vector *const transf_args); + template + void TransformArgsForFuncGraph(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg, + const T ¤t_primal_fg, const FuncGraphPtr &outer, + std::vector *const transf_args); void CheckBprop(const FuncGraphPtr &bprop_fg, const string &prim_to_check); Registry bprop_registry_; @@ -163,7 +174,8 @@ class KPrim { }; template -FuncGraphPtr KPrim::BpropToK(const T &primal, const FuncGraphPtr &bprop_fg, const CNodePtr &cnode) { +FuncGraphPtr KPrim::BpropToK(const T &primal, const FuncGraphPtr &bprop_fg, const FuncGraphPtr ¤t_primal_fg, + const CNodePtr &cnode) { MS_EXCEPTION_IF_NULL(primal); MS_EXCEPTION_IF_NULL(bprop_fg); CheckBprop(bprop_fg, primal->ToString()); @@ -177,7 +189,13 @@ FuncGraphPtr KPrim::BpropToK(const T &primal, const FuncGraphPtr &bprop_fg, cons cloned_bprop_fg->debug_info()->set_name(""); cloned_bprop_fg->debug_info()->set_trace_info(std::make_shared(debug_info)); - AnfNodePtr bout = BuildOutput(cloned_bprop_fg); + // Make sure (out, dout) provided. + if (cloned_bprop_fg->parameters().size() < 2) { + MS_LOG(EXCEPTION) << "Primitive or Cell " << primal->ToString() + << " bprop requires out and dout at least, but only got " << cloned_bprop_fg->parameters().size() + << " params. NodeInfo: " << trace::GetDebugInfo(cloned_bprop_fg->debug_info()); + } + AnfNodePtr bout = BuildOutput(cloned_bprop_fg, current_primal_fg); cloned_bprop_fg->set_output(bout); FuncGraphPtr outer = nullptr; @@ -190,24 +208,30 @@ FuncGraphPtr KPrim::BpropToK(const T &primal, const FuncGraphPtr &bprop_fg, cons auto mng = Manage({cloned_bprop_fg, outer}, false); - // Make sure (out, dout) provided. - if (cloned_bprop_fg->parameters().size() < 2) { - MS_LOG(EXCEPTION) << "Primitive or Cell " << primal->ToString() - << " bprop requires out and dout at least, but only got " << cloned_bprop_fg->parameters().size() - << " params. NodeInfo: " << trace::GetDebugInfo(cloned_bprop_fg->debug_info()); - } - // In a bprop definition, the last two param should be out and dout. 
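  // For example, for a primitive with inputs (x, y) the bprop convention is
  //   bprop(x, y, out, dout) -> (dx, dy)
  // which is why the two trailing parameters are peeled off below.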
- auto dout = cloned_bprop_fg->parameters()[cloned_bprop_fg->parameters().size() - 1]; - auto out_param = cloned_bprop_fg->parameters()[cloned_bprop_fg->parameters().size() - 2]; + auto param_size = cloned_bprop_fg->parameters().size(); + auto param_num = param_size - 1; + auto dout = cloned_bprop_fg->parameters()[param_num]; + param_num--; + auto out_param = cloned_bprop_fg->parameters()[param_num]; + std::vector transf_args; - TransformArgs(mng, cloned_bprop_fg, outer, &transf_args); - (void)transf_args.insert(transf_args.begin(), NewValueNode(primal)); + if constexpr (std::is_same::value) { + PrimitivePtr primitive = primal; + TransformArgsForPrimitive(mng, cloned_bprop_fg, primal, outer, &transf_args); + (void)transf_args.insert(transf_args.begin(), NewValueNode(primal)); + } else { + TransformArgsForFuncGraph(mng, cloned_bprop_fg, current_primal_fg, outer, &transf_args); + (void)transf_args.insert(transf_args.begin(), NewValueNode(current_primal_fg)); + } CNodePtr out_value = nullptr; if (cnode != nullptr) { // Set equiv debug info. for Primitive CNode out. TraceGuard trace_guard(std::make_shared(cnode->debug_info())); out_value = outer->NewCNode(transf_args); + if constexpr (std::is_same::value) { + out_value->CloneUserData(cnode); + } } else { out_value = outer->NewCNode(transf_args); } diff --git a/mindspore/ccsrc/frontend/optimizer/ad/grad.cc b/mindspore/ccsrc/frontend/optimizer/ad/grad.cc index 62cc23a2a0..851341c9bc 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/grad.cc +++ b/mindspore/ccsrc/frontend/optimizer/ad/grad.cc @@ -55,6 +55,8 @@ FuncGraphPtr Grad(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePt f->MapMorphism(); f->Finish(); auto res = f->k_graph(); + auto tape = f->tape(); + tape->set_flag(mindspore::kFuncGraphFlagBackPropEntry, true); if (is_top) { DFunctor::Clear(); } diff --git a/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc b/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc index 6cc6468b12..47713ff6de 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc +++ b/mindspore/ccsrc/frontend/optimizer/ad/kprim.cc @@ -67,6 +67,11 @@ FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim) { MS_LOG(ERROR) << "Fail to parse bprop function for " << prim->name() << "."; return nullptr; } + auto bprop_flag = GetPrimitiveFlag(prim, GRAPH_FLAG_SIDE_EFFECT_BACKPROP); + if (bprop_flag) { + func_graph->set_flag(mindspore::kFuncGraphFlagReAutoMonad, true); + } + return func_graph; } @@ -102,6 +107,38 @@ MetaFuncGraphPtr KPrim::KMetaFuncGraph(const PrimitivePtr &prim) { MS_LOG(EXCEPTION) << "Fail to find bprop function for " << prim->name() << "."; } +static void AppendMonadOutput(const FuncGraphPtr &bprop_fg, const AnfNodePtr &monad) { + const auto &output = bprop_fg->output(); + MS_EXCEPTION_IF_NULL(output); + auto output_cnode = output->cast(); + if (output_cnode != nullptr) { + // If output_cnode has the form like (make_tuple, x, y). + output_cnode->add_input(monad); + return; + } + // If output is an empty tuple, create a (make_tuple, monad) as the new output. + auto make_tuple = NewValueNode(prim::kPrimMakeTuple); + output_cnode = bprop_fg->NewCNode({make_tuple, monad}); + bprop_fg->set_output(output_cnode); +} + +// Append U or/and IO monad to output of Bprop funcgraph. 
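+// e.g. (illustrative) for a primitive with a memory side effect:
+//   before: bprop(...) -> (dx, dy)
+//   after:  bprop(...) -> (dx, dy, U)
+// keeping the gradient tuple aligned with the monad-extended primal inputs.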
+static void AdjustForAutoMonad(const PrimitivePtr &prim, const FuncGraphPtr &bprop_fg) {
+  auto effect_info = GetPrimEffectInfo(prim);
+  if (effect_info.memory) {
+    MS_LOG(DEBUG) << "Append U monad for Bprop FuncGraph of Primitive " << prim->ToString();
+    auto u = NewValueNode(kUMonad);
+    u->set_abstract(kUMonad->ToAbstract());
+    AppendMonadOutput(bprop_fg, u);
+  }
+  if (effect_info.io) {
+    MS_LOG(DEBUG) << "Append IO monad for Bprop FuncGraph of Primitive " << prim->ToString();
+    auto io = NewValueNode(kIOMonad);
+    io->set_abstract(kIOMonad->ToAbstract());
+    AppendMonadOutput(bprop_fg, io);
+  }
+}
+
 FuncGraphPtr KPrim::KPrimitive(const CNodePtr &cnode, const ValueNodePtr &value_node,
                                const pipeline::ResourceBasePtr &resources) {
   if (!IsValueNode(value_node)) {
@@ -141,8 +178,8 @@ FuncGraphPtr KPrim::KPrimitive(const CNodePtr &cnode, const ValueNodePtr &value_
       }
     }
   }
-
-  auto expanded_fg = BpropToK(prim, bprop_fg, cnode);
+  AdjustForAutoMonad(prim, bprop_fg);
+  auto expanded_fg = BpropToK(prim, bprop_fg, nullptr, cnode);
   if (expanded_fg == nullptr) {
     MS_LOG(EXCEPTION) << "Failed convert " << prim->name()
                       << " prim bprop function to J expanded func graph. NodeInfo: "
@@ -152,7 +189,23 @@ FuncGraphPtr KPrim::KPrimitive(const CNodePtr &cnode, const ValueNodePtr &value_
   return expanded_fg;
 }

-AnfNodePtr KPrim::BuildOutput(const FuncGraphPtr &bprop_fg) {
+AnfNodePtr KPrim::BuildOutput(const FuncGraphPtr &bprop_fg, const FuncGraphPtr &current_primal_fg) {
+  // current_primal_fg may have extra parameters like u_monad, io_monad.
+  std::vector extra_args;
+  // The caller has checked that size() - 2 is greater than 0.
+  auto bprop_fg_param_size = bprop_fg->parameters().size() - 2;
+  if (current_primal_fg != nullptr && bprop_fg_param_size < current_primal_fg->parameters().size()) {
+    auto current_primal_fg_param_size = current_primal_fg->parameters().size();
+    MS_LOG(DEBUG) << "Current primal FuncGraph may have extra parameters (U or IO monad) which the bprop does not "
                     "define, so insert them. Extra parameters size: "
                  << current_primal_fg_param_size - bprop_fg_param_size;
+    for (auto i = bprop_fg_param_size; i < current_primal_fg_param_size; ++i) {
+      const auto &primal_node = current_primal_fg->parameters()[i];
+      auto extra_node = bprop_fg->NewCNode({NewValueNode(prim::GetPythonOps("zeros_like")), primal_node});
+      extra_args.push_back(extra_node);
+      MS_LOG(DEBUG) << "Insert to bprop_fg for node: " << primal_node->DebugString();
+    }
+  }
   // bprop_fg has been checked in caller
   if (IsPrimitiveCNode(bprop_fg->output(), prim::kPrimMakeTuple)) {
     // Set bprop output as (env, dx, dy, dz, ...)
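  // With monad-extended primals, the assumed effect of the extra_args handling
  // added above is (illustrative shapes only):
  //   (env, dx, dy)  ->  (env, dx, dy, zeros_like(u_monad))
  // so the bprop output arity matches current_primal_fg's parameter list.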
@@ -163,22 +216,33 @@ AnfNodePtr KPrim::BuildOutput(const FuncGraphPtr &bprop_fg) { args.push_back(NewValueNode(prim::kPrimMakeTuple)); args.push_back(NewValueNode(newenv)); (void)args.insert(args.end(), inputs.begin() + 1, inputs.end()); + if (!extra_args.empty()) { + args.insert(args.end(), extra_args.cbegin(), extra_args.cend()); + } return NewCNode(args, bprop_fg); } // Set bprop output as (env, dx) std::string model_name("mindspore.ops.composite.multitype_ops.add_impl"); std::string python_ops("_tuple_add"); - auto tuple = NewCNode({NewValueNode(prim::kPrimMakeTuple), NewValueNode(newenv)}, bprop_fg); - return NewCNode({NewValueNode(prim::GetPythonOps(python_ops, model_name)), tuple, bprop_fg->output()}, bprop_fg); + auto tuple_env = NewCNode({NewValueNode(prim::kPrimMakeTuple), NewValueNode(newenv)}, bprop_fg); + auto tuple_add_ops = NewValueNode(prim::GetPythonOps(python_ops, model_name)); + if (!extra_args.empty()) { + extra_args.insert(extra_args.begin(), NewValueNode(prim::kPrimMakeTuple)); + auto extra_tuple = NewCNode(extra_args, bprop_fg); + auto old_output_extra = NewCNode({tuple_add_ops, bprop_fg->output(), extra_tuple}, bprop_fg); + return NewCNode({tuple_add_ops, tuple_env, old_output_extra}, bprop_fg); + } + + return NewCNode({tuple_add_ops, tuple_env, bprop_fg->output()}, bprop_fg); } -void KPrim::TransformArgs(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg, const FuncGraphPtr &outer, - std::vector *const transf_args) { - MS_EXCEPTION_IF_NULL(mng); +static void TransformNormalArgs(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg, const FuncGraphPtr &outer, + std::vector *const transf_args) { // bprop_fg has been checked in caller // transform except the last 2 parameters: out, dout. - for (size_t i = 0; i < bprop_fg->parameters().size() - 2; ++i) { + auto bprop_fg_param_size = bprop_fg->parameters().size() - 2; + for (size_t i = 0; i < bprop_fg_param_size; ++i) { auto p = bprop_fg->parameters()[i]; MS_EXCEPTION_IF_NULL(p); @@ -189,6 +253,60 @@ void KPrim::TransformArgs(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bp transf_args->push_back(transf_p); } } +void KPrim::TransformArgsForPrimitive(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg, + const PrimitivePtr &primitive, const FuncGraphPtr &outer, + std::vector *const transf_args) { + MS_EXCEPTION_IF_NULL(mng); + TransformNormalArgs(mng, bprop_fg, outer, transf_args); + // Fprop_fg for Primitive with side effect should append extra U or IO monad parameter. 
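+  // e.g. (illustrative) for a primitive like Assign with a memory effect:
+  //   fprop(ref, value, u_monad) -> (out, bprop)
+  // where the u_monad parameter added here threads the effect through the K call.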
+  auto effect_info = GetPrimEffectInfo(primitive);
+  if (effect_info.memory) {
+    MS_LOG(DEBUG) << "Append U monad to Fprop FuncGraph for Primitive " << primitive->ToString();
+    auto transf_p = outer->add_parameter();
+    transf_args->push_back(transf_p);
+  }
+  if (effect_info.io) {
+    MS_LOG(DEBUG) << "Append IO monad to Fprop FuncGraph for Primitive " << primitive->ToString();
+    auto transf_p = outer->add_parameter();
+    transf_args->push_back(transf_p);
+  }
+}
+
+template
+void KPrim::TransformArgsForFuncGraph(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg,
+                                      const T &current_primal_fg, const FuncGraphPtr &outer,
+                                      std::vector *const transf_args) {
+  MS_EXCEPTION_IF_NULL(mng);
+  TransformNormalArgs(mng, bprop_fg, outer, transf_args);
+  auto bprop_fg_param_size = bprop_fg->parameters().size() - 2;
+  // current_primal_fg may have extra parameters after AutoMonad.
+  const auto &current_primal_fg_params = current_primal_fg->parameters();
+  if (bprop_fg_param_size < current_primal_fg_params.size()) {
+    for (auto i = bprop_fg_param_size; i < current_primal_fg_params.size(); ++i) {
+      auto p = current_primal_fg_params[i];
+      MS_EXCEPTION_IF_NULL(p);
+      // Extra parameters should be monads.
+      if (!HasAbstractMonad(p)) {
+        continue;
+      }
+      MS_LOG(DEBUG) << "Function " << current_primal_fg->ToString()
+                    << ", has extra monad parameter: " << p->DebugString()
+                    << ", abstract: " << p->abstract()->ToString();
+
+      TraceGuard trace_guard(std::make_shared(p->debug_info()));
+      auto transf_p = outer->add_parameter();
+
+      (void)mng->Replace(p, transf_p);
+      transf_args->push_back(transf_p);
+    }
+  }
+  if (transf_args->size() != current_primal_fg_params.size()) {
+    MS_EXCEPTION(TypeError) << "Function " << current_primal_fg->ToString()
+                            << ", the number of parameters of this primal function is "
+                            << current_primal_fg_params.size() << ", but the number of parameters of bprop is "
+                            << bprop_fg_param_size;
+  }
+}

 void KPrim::CheckBprop(const FuncGraphPtr &bprop_fg, const string &prim_to_check) {
   auto context = MsContext::GetInstance();
@@ -218,14 +336,16 @@ void KPrim::CheckBprop(const FuncGraphPtr &bprop_fg, const string &prim_to_check
   bprop_fg->set_output(bprop_out);
 }

-FuncGraphPtr KPrim::KUserDefinedCellBprop(const FuncGraphPtr bprop_fg) {
+FuncGraphPtr KPrim::KUserDefinedCellBprop(const FuncGraphPtr &bprop_fg, const FuncGraphPtr &current_primal_fg) {
   MS_EXCEPTION_IF_NULL(bprop_fg);
-  auto fprop_fg = bprop_fg->transforms().find("primal")->second.func_graph();
-  auto expanded_fg = BpropToK(fprop_fg, bprop_fg, nullptr);
+  // primal_fg is the FuncGraph just after convert; refer to ConvertCellObjToFuncGraph.
+  // current_primal_fg is the specialized and AutoMonaded primal_fg.
+  auto primal_fg = bprop_fg->transforms().find("primal")->second.func_graph();
+  auto expanded_fg = BpropToK(primal_fg, bprop_fg, current_primal_fg, nullptr);
   if (expanded_fg == nullptr) {
-    MS_LOG(EXCEPTION) << "Failed convert " << fprop_fg->ToString()
+    MS_LOG(EXCEPTION) << "Failed convert " << primal_fg->ToString()
                       << " Cell bprop function to K expanded func graph. NodeInfo: "
-                      << trace::GetDebugInfo(fprop_fg->debug_info());
+                      << trace::GetDebugInfo(primal_fg->debug_info());
   }
   return expanded_fg;
 }
@@ -283,6 +403,20 @@ FuncGraphPtr KPrim::FakeBprop(const ValueNodePtr &value_node, const pipeline::Re
     MS_LOG(EXCEPTION) << "Fail to find user for " << prim->ToString();
   }
   auto inputs_num = cnode->first->cast()->inputs().size() - 1;
+  auto effect_info = GetPrimEffectInfo(prim);
+  // Don't add U or IO monad parameters as they will be added later.
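+  // e.g. (illustrative) a CNode print(x, io_monad) has inputs_num = 2 and
+  // monad_params_size = 1, so the fake bprop covers only the real input x.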
+ size_t monad_params_size = 0; + if (effect_info.memory) { + monad_params_size++; + } + if (effect_info.io) { + monad_params_size++; + } + if (inputs_num < monad_params_size) { + MS_LOG(EXCEPTION) << "Arguments number should be greater than or equal to " << monad_params_size + << ", but the CNode is: " << cnode->first->DebugString(); + } + inputs_num -= monad_params_size; auto func_graph = std::make_shared(); std::vector outputs; diff --git a/mindspore/ccsrc/frontend/optimizer/control_depend.cc b/mindspore/ccsrc/frontend/optimizer/control_depend.cc deleted file mode 100644 index 871cc7d003..0000000000 --- a/mindspore/ccsrc/frontend/optimizer/control_depend.cc +++ /dev/null @@ -1,121 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "frontend/optimizer/control_depend.h" - -#include -#include -#include -#include - -#include "frontend/optimizer/optimizer.h" - -namespace mindspore { -namespace opt { -std::vector DoControlDepend(const FuncGraphPtr &graph, const CNodePtr &return_node, - const std::vector &effect_index, const std::vector &cnodes) { - std::vector depend_nodes{NewValueNode(prim::kPrimDepend), return_node->input(1)}; - std::vector make_tuple{NewValueNode(prim::kPrimMakeTuple)}; - size_t effect_size = effect_index.size(); - for (size_t i = 0; i < effect_size; i++) { - size_t pre_index = 0; - if (i > 0) { - pre_index = effect_index[i - 1] + 1; - } - size_t this_index = effect_index[i]; - size_t last_index = cnodes.size() - 2; - if (i < effect_size - 1) { - last_index = effect_index[i + 1]; - } - - if (this_index > pre_index) { - std::vector pre_segment; - for (size_t k = pre_index; k < this_index; k++) { - // Skip depend, make_tuple, and tuple_get_item, because these primitives are not real operator in GE. - if (IsPrimitiveCNode(cnodes[k], prim::kPrimDepend) || IsPrimitiveCNode(cnodes[k], prim::kPrimMakeTuple) || - IsPrimitiveCNode(cnodes[k], prim::kPrimTupleGetItem)) { - continue; - } - pre_segment.push_back(cnodes[k]); - } - auto roots = FindRoots(pre_segment); - for (auto iter = roots->begin(); iter != roots->end(); (void)iter++) { - AnfNodePtr control_depend = - graph->NewCNode({NewValueNode(prim::kPrimControlDepend), *iter, cnodes[this_index]}); - make_tuple.push_back(control_depend); - } - } - if (last_index > this_index) { - std::vector last_segment; - for (size_t k = this_index + 1; k <= last_index; k++) { - // Skip depend, make_tuple, and tuple_get_item, because these primitives are not real operator in GE. 
- if (IsPrimitiveCNode(cnodes[k], prim::kPrimDepend) || IsPrimitiveCNode(cnodes[k], prim::kPrimMakeTuple) || - IsPrimitiveCNode(cnodes[k], prim::kPrimTupleGetItem)) { - continue; - } - last_segment.push_back(cnodes[k]); - } - auto leaves = FindLeaves(last_segment); - for (auto iter = leaves->begin(); iter != leaves->end(); (void)iter++) { - AnfNodePtr control_depend = - graph->NewCNode({NewValueNode(prim::kPrimControlDepend), cnodes[this_index], *iter}); - make_tuple.push_back(control_depend); - } - } - } - depend_nodes.push_back(graph->NewCNode(make_tuple)); - return depend_nodes; -} - -void AddControlDepend(const FuncGraphPtr &graph) { - MS_EXCEPTION_IF_NULL(graph); - std::list orders = graph->GetOrderedCnodes(); - std::vector cnodes(orders.begin(), orders.end()); - size_t cnodes_size = cnodes.size(); - // get effect index of cnodes - std::vector effect_index{}; - for (size_t i = 0; i < cnodes_size; i++) { - if (graph->HasEffect(cnodes[i])) { - effect_index.push_back(i); - } - } - if (effect_index.empty()) { - return; - } - AnfNodePtr last_node = cnodes[cnodes_size - 1]; - CNodePtr return_node; - if (last_node->isa()) { - return_node = last_node->cast(); - } - MS_EXCEPTION_IF_NULL(return_node); - if (!IsPrimitiveCNode(return_node, prim::kPrimReturn)) { - MS_LOG(EXCEPTION) << "The last cnode after sorting, not return cnode."; - } - if (return_node->inputs().size() < 2) { - MS_LOG(EXCEPTION) << "Number of return node inputs should be great than or equal to 2."; - } - - auto depend_node_inputs = DoControlDepend(graph, return_node, effect_index, cnodes); - auto depend_cnode = graph->NewCNode(depend_node_inputs); - depend_cnode->set_abstract(depend_cnode->input(1)->abstract()); - auto manager = graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - if (!manager->Replace(return_node->input(1), depend_cnode)) { - MS_LOG(EXCEPTION) << "Depend replace node failed"; - } -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/optimizer/cse.cc b/mindspore/ccsrc/frontend/optimizer/cse.cc index 35ce7e53b2..d227386212 100644 --- a/mindspore/ccsrc/frontend/optimizer/cse.cc +++ b/mindspore/ccsrc/frontend/optimizer/cse.cc @@ -21,10 +21,13 @@ #include #include #include +#include +#include #include "abstract/abstract_function.h" #include "utils/flags.h" #include "utils/utils.h" +#include "base/core_ops.h" namespace mindspore { /* namespace to support opt */ @@ -116,6 +119,245 @@ bool CSE::BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const { return changed; } + +std::vector> GenerateLoadGroups(const FuncGraphPtr &fg, const std::vector &toposet, + std::vector *need_replace_loads) { + std::unordered_map load_groups_record; + std::vector> load_groups; + std::unordered_set unload_users_record; + for (size_t i = 0; i < toposet.size(); i++) { + auto &node = toposet[i]; + auto cnode = node->cast(); + if (cnode == nullptr) { + continue; + } + if (!IsPrimitiveCNode(cnode, prim::kPrimLoad)) { + for (const auto &input : cnode->inputs()) { + if (input->isa()) { + unload_users_record.insert(input); + } + } + continue; + } + // Exclude free variable node. + if (cnode->func_graph() != fg) { + continue; + } + auto load_param = cnode->input(1); + // first time get same input1 of load. 
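+    // Illustrative grouping (names made up):
+    //   %a = Load(%para1, %u1) ... %b = Load(%para1, %u2)
+    // both land in one group keyed by %para1; later passes merge %b into %a
+    // when nothing writes %para1 in between.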
+    if (load_groups_record.find(load_param) == load_groups_record.end()) {
+      load_groups_record[load_param] = load_groups.size();
+      load_groups.push_back({i});
+      if (unload_users_record.find(load_param) == unload_users_record.end()) {
+        need_replace_loads->emplace_back(cnode);
+      }
+    } else {
+      // Not the first time this Load input is seen: append to its group.
+      load_groups[load_groups_record[load_param]].push_back(i);
+    }
+  }
+  return load_groups;
+}
+
+std::vector> SplitGroup(const std::vector &toposet, const std::vector &group) {
+  if (group.size() <= 1) {
+    return {};
+  }
+  auto load_param = toposet[group.back()]->cast()->input(1);
+  size_t cur_load_index = 1;
+  size_t pre_load_index = 0;
+  std::vector cur_group = {group[pre_load_index]};
+  std::vector> split_groups;
+  while (cur_load_index < group.size()) {
+    const auto &cur_load = group[cur_load_index];
+    const auto &prev_load = group[pre_load_index];
+    const auto param_used_by_other =
+      std::any_of(toposet.begin() + prev_load, toposet.begin() + cur_load, [&load_param](const AnfNodePtr &node) {
+        if (!node->isa()) {
+          return false;
+        }
+        if (IsPrimitiveCNode(node, prim::kPrimLoad)) {
+          return false;
+        }
+        auto cnode = node->cast();
+        auto &inputs = cnode->inputs();
+        return std::any_of(inputs.begin(), inputs.end(),
+                           [&load_param](const AnfNodePtr &input) { return load_param == input; });
+      });
+    if (param_used_by_other) {
+      split_groups.push_back(cur_group);
+      cur_group.clear();
+    }
+    cur_group.push_back(cur_load);
+    pre_load_index++;
+    cur_load_index++;
+  }
+  // Push back the last split group.
+  split_groups.push_back(cur_group);
+  return split_groups;
+}
+
+// Pattern1======================================
+// a = Load(para1, u1)
+// ...
+// b = Load(para1, u2)
+// u3 = UpdateState(u2, b)
+//==>
+// delete the UpdateState
+void DeleteLoadUserUpdateState(const FuncGraphManagerPtr &manager, const AnfNodePtr &load_user,
+                               const AnfNodePtr &load) {
+  const auto &load_cnode = load->cast();
+  const auto &u = load_cnode->input(2);
+  manager->Replace(load_user, u);
+}
+
+// Pattern2======================================
+// a = Load(para1, u1)
+// ...
+// b = Load(para1, u2)
+// t = make_tuple(x, b)
+// u3 = UpdateState(u2, t)
+//==>
+// a = Load(para1, u1)
+// ...
+// b = Load(para1, u2)
+// u3 = UpdateState(u2, x)
+void DeleteLoadUserMakeTuple(const FuncGraphManagerPtr &manager, const CNodePtr &make_tuple, const AnfNodePtr &load) {
+  AnfNodePtr other_input = nullptr;
+  for (size_t i = 1; i < make_tuple->size(); i++) {
+    if (make_tuple->input(i) != load) {
+      other_input = make_tuple->input(i);
+      break;
+    }
+  }
+  MS_EXCEPTION_IF_NULL(other_input);
+  manager->Replace(make_tuple, other_input);
+}
+
+// Pattern3======================================
+// a = Load(para1, u1)
+// ...
+
+// Pattern1======================================
+// a = Load(para1, u1)
+// ...
+// b = Load(para1, u2)
+// u3 = UpdateState(u2, b)
+//==>
+// delete the UpdateState
+void DeleteLoadUserUpdateState(const FuncGraphManagerPtr &manager, const AnfNodePtr &load_user,
+                               const AnfNodePtr &load) {
+  const auto &load_cnode = load->cast<CNodePtr>();
+  const auto &u = load_cnode->input(2);
+  manager->Replace(load_user, u);
+}
+
+// Pattern2======================================
+// a = Load(para1, u1)
+// ...
+// b = Load(para1, u2)
+// t = make_tuple(x, b)
+// u3 = UpdateState(u2, t)
+//==>
+// a = Load(para1, u1)
+// ...
+// b = Load(para1, u2)
+// u3 = UpdateState(u2, x)
+void DeleteLoadUserMakeTuple(const FuncGraphManagerPtr &manager, const CNodePtr &make_tuple, const AnfNodePtr &load) {
+  AnfNodePtr other_input = nullptr;
+  for (size_t i = 1; i < make_tuple->size(); i++) {
+    if (make_tuple->input(i) != load) {
+      other_input = make_tuple->input(i);
+      break;
+    }
+  }
+  MS_EXCEPTION_IF_NULL(other_input);
+  manager->Replace(make_tuple, other_input);
+}
+
+// Pattern3======================================
+// a = Load(para1, u1)
+// ...
+// b = Load(para1, u2)
+// t = make_tuple(x, y, b, z)
+// u3 = UpdateState(u2, t)
+//==>
+// a = Load(para1, u1)
+// ...
+// b = Load(para1, u2)
+// t = make_tuple(x, y, z)
+// u3 = UpdateState(u2, t)
+void ReplaceLoadUserMakeTuple(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg, const CNodePtr &make_tuple,
+                              const AnfNodePtr &load) {
+  auto &make_tuple_inputs = make_tuple->inputs();
+  std::vector<AnfNodePtr> new_make_tuple_inputs;
+  (void)std::copy_if(make_tuple_inputs.begin(), make_tuple_inputs.end(), std::back_inserter(new_make_tuple_inputs),
+                     [load](const AnfNodePtr &input) { return load != input; });
+  const auto &new_make_tuple = fg->NewCNode(new_make_tuple_inputs);
+  manager->Replace(make_tuple, new_make_tuple);
+}
+
+void ReplaceLoadUser(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg, const AnfNodePtr &load) {
+  auto load_users = manager->node_users()[load];
+  for (const auto &load_user : load_users) {
+    // Pattern1
+    if (IsPrimitiveCNode(load_user.first, prim::kPrimUpdateState)) {
+      DeleteLoadUserUpdateState(manager, load_user.first, load);
+      continue;
+    }
+    if (IsPrimitiveCNode(load_user.first, prim::kPrimMakeTuple)) {
+      const auto &make_tuple = load_user.first->cast<CNodePtr>();
+      auto &maketuple_users = manager->node_users()[make_tuple];
+      auto maketuple_as_input_of_update =
+        maketuple_users.size() == 1 && IsPrimitiveCNode(maketuple_users.back().first, prim::kPrimUpdateState);
+      if (!maketuple_as_input_of_update) {
+        continue;
+      }
+      // Pattern2
+      if (make_tuple->size() == 3) {
+        DeleteLoadUserMakeTuple(manager, make_tuple, load);
+        continue;
+      }
+      // Pattern3
+      if (make_tuple->size() > 3) {
+        ReplaceLoadUserMakeTuple(manager, fg, make_tuple, load);
+      }
+    }
+  }
+}
+
+bool ReplaceSameGroupLoad(const FuncGraphManagerPtr &manager, const FuncGraphPtr &fg,
+                          const std::vector<AnfNodePtr> &toposet, const std::vector<size_t> &group) {
+  if (group.size() <= 1) {
+    return false;
+  }
+  const auto &main = toposet[group[0]];
+  for (size_t i = 1; i < group.size(); i++) {
+    ReplaceLoadUser(manager, fg, toposet[group[i]]);
+    manager->Replace(toposet[group[i]], main);
+  }
+  return true;
+}
+
+AnfNodePtr GetFirstMonad(const FuncGraphPtr &fg) {
+  auto &params = fg->parameters();
+  auto end = (params.size() > 1) ? (params.rbegin() + 2) : params.rend();
+  auto iter = std::find_if(params.rbegin(), end, [](const AnfNodePtr &para) { return HasAbstractUMonad(para); });
+  if (iter != end) {
+    return *iter;
+  }
+  auto monad = NewValueNode(kUMonad);
+  monad->set_abstract(kUMonad->ToAbstract());
+  return monad;
+}
+
+// Replace UpdateStates with U for the first load.
+// Convert:
+// u1 = UpdateState(u, c)
+// p1 = Load(para1, u1)  // first load for para1
+// To:
+// u1 = UpdateState(u, c)
+// p1 = Load(para1, u')  // u' is the first monad in the graph, or a new monad
+void ReplaceUpdateStateForLoad(const FuncGraphPtr &fg, const std::vector<AnfNodePtr> &need_replace_loads) {
+  constexpr size_t second_input_index = 2;
+  auto monad = GetFirstMonad(fg);
+  for (const auto &load_node : need_replace_loads) {
+    if (!IsPrimitiveCNode(load_node, prim::kPrimLoad)) {
+      continue;
+    }
+    auto update_state = load_node->cast<CNodePtr>()->input(second_input_index);
+    if (!IsPrimitiveCNode(update_state, prim::kPrimUpdateState)) {
+      continue;
+    }
+    auto mgr = fg->manager();
+    mgr->SetEdge(load_node, second_input_index, monad);
+  }
+}
+
+// Node1{primLoad,X,Y1},...,Node{Node's input != X},...,Node2{primLoad,X,Y2},... =>
+// Node1{primLoad,X,Y1},...,Node{Node's input != X},...,Node1,...
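+// Editor's note (sketch): within each surviving group, every later Load is
+// redirected to the group's first Load by ReplaceSameGroupLoad(), and its
+// UpdateState/MakeTuple users are rewired through ReplaceLoadUser() so that
+// the monad chain stays consistent.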
+bool CSE::ReplaceAutoMonadNode(const FuncGraphManagerPtr &manager) const {
+  auto changed = false;
+  for (const FuncGraphPtr &fg : manager->func_graphs()) {
+    std::vector<AnfNodePtr> toposet = TopoSort(fg->get_return());
+    std::vector<AnfNodePtr> need_replace_loads;
+    std::vector<std::vector<size_t>> load_groups = GenerateLoadGroups(fg, toposet, &need_replace_loads);
+    ReplaceUpdateStateForLoad(fg, need_replace_loads);
+    // Split a group when a non-Load node uses the parameter between two Loads.
+    std::vector<std::vector<size_t>> need_merge_loads;
+    for (auto &group : load_groups) {
+      auto groups = SplitGroup(toposet, group);
+      need_merge_loads.insert(need_merge_loads.end(), groups.begin(), groups.end());
+    }
+    for (auto &group : need_merge_loads) {
+      const bool replaced = ReplaceSameGroupLoad(manager, fg, toposet, group);
+      if (!changed && replaced) {
+        changed = true;
+      }
+    }
+  }
+  return changed;
+}
+
 // Ops like print or summary have no real output, and are always used as the input of a Depend node.
 static bool HasSideEffect(const AnfNodePtr &node) {
   auto prim = GetCNodePrimitive(node);
@@ -255,8 +497,9 @@ bool CSE::DoReplace(const FuncGraphManagerPtr manager, const std::vector<std::size_t> &order_group,
   manager->AddFuncGraph(root);
-
-  return BuildOrderGroupAndDoReplace(manager);
+  auto change1 = ReplaceAutoMonadNode(manager);
+  auto change2 = BuildOrderGroupAndDoReplace(manager);
+  return change1 || change2;
 }
 }  // namespace opt
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/optimizer/cse.h b/mindspore/ccsrc/frontend/optimizer/cse.h
index abfcd635e9..12341bbfb8 100644
--- a/mindspore/ccsrc/frontend/optimizer/cse.h
+++ b/mindspore/ccsrc/frontend/optimizer/cse.h
@@ -42,6 +42,7 @@ class CSE {
 
  private:
   bool BuildOrderGroupAndDoReplace(const FuncGraphManagerPtr manager) const;
+  bool ReplaceAutoMonadNode(const FuncGraphManagerPtr &manager) const;
   bool DoReplace(const FuncGraphManagerPtr manager, const std::vector<std::size_t> &order_group,
                  std::unordered_map<std::size_t, std::vector<AnfNodePtr>> *groups) const;
 };
diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.cc b/mindspore/ccsrc/frontend/optimizer/irpass.cc
index 4978338271..80170019e7 100644
--- a/mindspore/ccsrc/frontend/optimizer/irpass.cc
+++ b/mindspore/ccsrc/frontend/optimizer/irpass.cc
@@ -23,6 +23,8 @@
 #include "frontend/optimizer/irpass/grad_var_prepare.h"
 #include "frontend/optimizer/irpass/gradient_eliminate.h"
 #include "frontend/optimizer/irpass/inline.h"
+#include "frontend/optimizer/irpass/updatestate_eliminate.h"
+#include "frontend/optimizer/irpass/stopgrad_eliminate.h"
 #include "frontend/optimizer/irpass/incorporate_call.h"
 #include "frontend/optimizer/irpass/incorporate_getitem.h"
 #include "frontend/optimizer/irpass/item_tuple_or_list_eliminate.h"
@@ -65,6 +67,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
     MakeSubstitution(std::make_shared<ZeroLikeFillZero>(), "zero_like_fill_zero", prim::kPrimZerosLike);
   adjust_all_reduce_mul_add_ =
     MakeSubstitution(std::make_shared<AdjustAllReduceMulAdd>(), "adjust_all_reduce_mul_add", prim::kPrimAddN);
+  float_depend_g_call_ = MakeSubstitution(std::make_shared<FloatDependGCall>(), "float_depend_g_call", IsCNodeDup);
 
   // ops eliminate
   item_tuple_or_list_eliminate_ = MakeSubstitution(
@@ -128,6 +131,8 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
     MakeSubstitution(std::make_shared<FloatEnvGetItemSwitch>(), "float_env_getitem_switch", prim::kPrimEnvGetItem);
   convert_switch_replacement_ =
     MakeSubstitution(std::make_shared<ConvertSwitchReplacement>(), "convert_switch_replacement", IsCNodeDup);
+  exchange_switch_depend_value_ =
+    MakeSubstitution(std::make_shared<ExchangeSwitchDependValue>(), "exchange_switch_depend_value", prim::kPrimSwitch);
 
   // Addn
   merge_addn_ = MakeSubstitution(std::make_shared<MergeAddN>(), "merge_addn", prim::kPrimAddN);
@@ -145,6 +150,16 @@
OptimizeIRPassLib::OptimizeIRPassLib() { specialize_transform_ = MakeSubstitution(std::make_shared(), "specialize_transform", IsCNodeGraph); + // UpdateState eliminate + updatestate_eliminater_ = + MakeSubstitution(std::make_shared(), "updatestate_eliminater", prim::kPrimUpdateState); + switch_call_monad_eliminater_ = MakeSubstitution(std::make_shared(), + "switch_call_monad_eliminater", IsCNodeDup); + + // StopGradient eliminate + stopgrad_eliminater_ = + MakeSubstitution(std::make_shared(), "stopgrad_eliminater", prim::kPrimStopGradient); + // Incorporation incorporate_getitem_set_ = MakeSubstitution(std::make_shared(), "incorporate_getitem_set", prim::kPrimTupleGetItem); @@ -187,6 +202,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() { std::make_shared(), "row_tensor_eliminate", {prim::kPrimRowTensorGetIndices, prim::kPrimRowTensorGetValues, prim::kPrimRowTensorGetDenseShape}); + // RowTensorAddZerosLike Eliminate + row_tensor_add_zeros_like_ = + MakeSubstitution(std::make_shared(), "row_tensor_add_zeros_like", prim::kPrimRowTensorAdd); + // SparseTensor Eliminate sparse_tensor_eliminate_ = MakeSubstitution( std::make_shared(), "sparse_tensor_eliminate", diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.h b/mindspore/ccsrc/frontend/optimizer/irpass.h index 3aa3d5305c..003e54d939 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass.h @@ -37,7 +37,7 @@ class OptimizeIRPassLib { SubstitutionPtr special_op_eliminate_; SubstitutionPtr zero_like_fill_zero_; SubstitutionPtr adjust_all_reduce_mul_add_; - + SubstitutionPtr float_depend_g_call_; // ops eliminate SubstitutionPtr item_tuple_or_list_eliminate_; SubstitutionPtr tile_eliminate_; @@ -73,6 +73,7 @@ class OptimizeIRPassLib { SubstitutionPtr float_tuple_getitem_switch_; SubstitutionPtr float_env_getitem_switch_; SubstitutionPtr convert_switch_replacement_; + SubstitutionPtr exchange_switch_depend_value_; // AddN SubstitutionPtr merge_addn_; @@ -91,6 +92,11 @@ class OptimizeIRPassLib { SubstitutionPtr replace_applicator_; SubstitutionPtr specialize_transform_; + // Auto-monad related eliminaters. 
+ SubstitutionPtr updatestate_eliminater_; + SubstitutionPtr switch_call_monad_eliminater_; + SubstitutionPtr stopgrad_eliminater_; + // Incorporation SubstitutionPtr incorporate_getitem_set_; SubstitutionPtr incorporate_getitem_from_param_; @@ -122,6 +128,9 @@ class OptimizeIRPassLib { // RowTensor Eliminate SubstitutionPtr row_tensor_eliminate_; + // RowTensorAddZerosLike Eliminate + SubstitutionPtr row_tensor_add_zeros_like_; + // SparseTensor Eliminate SubstitutionPtr sparse_tensor_eliminate_; @@ -177,6 +186,13 @@ inline bool IsParam(const AnfNodePtr &node) { return false; } +inline bool IsLoad(const AnfNodePtr &node) { + if (node == nullptr || !node->isa()) { + return false; + } + return IsPrimitiveCNode(node, prim::kPrimLoad); +} + // Check if CNode Input 0 is Func Graph inline bool IsCNodeGraph(const AnfNodePtr &node) { if (node == nullptr || !node->isa()) { diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc b/mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc index f04b3b506c..b96d0381e3 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc +++ b/mindspore/ccsrc/frontend/optimizer/irpass/arithmetic_simplify.cc @@ -83,7 +83,8 @@ AnfNodePtr ArithmeticSimplify2::operator()(const OptimizerPtr &, const AnfNodePt // Multiply by zero MATCH_REPLACE_IF(node, x * zero_, zero_.WithShapeAs(node), - !zero_.CheckFunc(IsParam, node) && x.GetNode(node)->func_graph() == node->func_graph()); + !zero_.CheckFunc(IsParam, node) && !x.CheckFunc(IsLoad, node) && + x.GetNode(node)->func_graph() == node->func_graph()); auto zero_prim = PPrimitive(prim::kPrimZerosLike, y); MATCH_REPLACE_IF(node, x * zero_prim, zero_.WithShapeAs(node), !zero_prim.CheckFunc(IsParam, node) && x.GetNode(node)->func_graph() == node->func_graph()); diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/branch_culling.cc b/mindspore/ccsrc/frontend/optimizer/irpass/branch_culling.cc index cf45b890b8..04740e0c7f 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/branch_culling.cc +++ b/mindspore/ccsrc/frontend/optimizer/irpass/branch_culling.cc @@ -84,7 +84,7 @@ bool InConvertWhiteList(const AnfNodePtr &node, size_t index) { } } - std::vector adapter_convert_ops = {prim::kPrimDepend, prim::kPrimControlDepend}; + std::vector adapter_convert_ops = {prim::kPrimDepend, prim::kPrimControlDepend, prim::kPrimLoad}; for (auto &item : adapter_convert_ops) { if (IsPrimitiveCNode(node, item)) { return true; @@ -149,6 +149,10 @@ FuncGraphPtr TransformGraphCondBranchNodes( // if the apply input does not belong to graph, insert a switch node for (size_t index = 0; index < inputs.size(); index++) { auto input_node = inputs[index]; + if (HasAbstractMonad(input_node)) { + // Do not guard with switch for monad inputs. 
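+      // (Editor's note: a monad input, e.g. a UMonad or IOMonad value, only
+      // threads side-effect ordering and carries no data, so there is nothing
+      // for a switch to select between branches.)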
+ continue; + } MS_EXCEPTION_IF_NULL(input_node); // for some ops input should not guard it with switch if (InConvertWhiteList(node, index)) { diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/branch_culling.h b/mindspore/ccsrc/frontend/optimizer/irpass/branch_culling.h index 163d72ddab..30d7ac6f1d 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/branch_culling.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/branch_culling.h @@ -126,6 +126,14 @@ class ConvertSwitchReplacement : public OptimizerCaller { auto trans_g2 = internal::TransformGraphCondFalseBranchNodes(g2_, x_); std::vector params; + auto cnode = node->cast(); + if (cnode && cnode->size() > 1) { + // There are arguments for the call of switch result, + // usually these are monad states added by auto-monad. + for (size_t i = 1; i < cnode->size(); ++i) { + params.push_back(cnode->inputs().at(i)); + } + } auto fg = node->func_graph(); auto cloned_g1 = InlineClone(trans_g1, fg, params); auto cloned_g2 = InlineClone(trans_g2, fg, params); @@ -141,6 +149,25 @@ class ConvertSwitchReplacement : public OptimizerCaller { return nullptr; } }; + +// {prim::kPrimSwitch, {prim::kPrimDepend, ValueNode, X}, G1, G2} -> +// {prim::kPrimDepend, {prim::kPrimSwitch, ValueNode, G1, G2}, X} +class ExchangeSwitchDependValue : public OptimizerCaller { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + if (!node->isa() || node->func_graph() == nullptr) { + return nullptr; + } + ScopePtr scope = node->cast()->scope(); + ScopeGuard scope_guard(scope); + + PatternNode cond, true_br, false_br, v, x; + MATCH_REPLACE_IF(node, PPrimitive(prim::kPrimSwitch, PPrimitive(prim::kPrimDepend, v, x), true_br, false_br), + PPrimitive(prim::kPrimDepend, PPrimitive(prim::kPrimSwitch, v, true_br, false_br), x), + IsVNode(v.GetNode(node))); + return nullptr; + } +}; } // namespace irpass } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/env_item_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/env_item_eliminate.h index c337aad857..1bdc56e4d0 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/env_item_eliminate.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/env_item_eliminate.h @@ -279,15 +279,35 @@ class EnvGetSetItem : public AnfVisitor { bool is_match_{false}; }; +// {prim::kPrimEnvGetitem, {prim::kPrimDepend, X1, X2}, item, dflt} -> +// {prim::kPrimDepend, {prim::kPrimEnvGetitem, X1, item, dflt}, X2} +class SwapEnvGetItemDepend : public OptimizerCaller { + public: + AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override { + if (!node->isa() || node->func_graph() == nullptr) { + return nullptr; + } + ScopePtr scope = node->cast()->scope(); + ScopeGuard scope_guard(scope); + + PatternNode x1, x2, item, dflt; + MATCH_REPLACE(node, PPrimitive(prim::kPrimEnvGetItem, PPrimitive(prim::kPrimDepend, x1, x2), item, dflt), + PPrimitive(prim::kPrimDepend, PPrimitive(prim::kPrimEnvGetItem, x1, item, dflt), x2)); + return nullptr; + } +}; + class EnvGetItemEliminater : public OptimizerCaller { public: EnvGetItemEliminater() : new_env_get_item_(std::make_shared()), add_env_get_item_(std::make_shared()), - env_get_set_item_(std::make_shared()) { + env_get_set_item_(std::make_shared()), + swap_env_get_item_depend_(std::make_shared()) { eliminaters_.emplace_back(new_env_get_item_); eliminaters_.emplace_back(add_env_get_item_); eliminaters_.emplace_back(env_get_set_item_); + eliminaters_.emplace_back(swap_env_get_item_depend_); } 
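+  // Editor's sketch (assumption, not part of the original change): SwapEnvGetItemDepend
+  // lets env_getitem look through a Depend that auto-monad inserted, e.g.
+  //   env_getitem(Depend(e, u), item, dflt) -> Depend(env_getitem(e, item, dflt), u)
+  // so the remaining env_getitem eliminaters can fold the hoisted env_getitem.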
 ~EnvGetItemEliminater() = default;
@@ -303,7 +323,7 @@ class EnvGetItemEliminater : public OptimizerCaller {
   }
 
  private:
-  OptimizerCallerPtr new_env_get_item_, add_env_get_item_, env_get_set_item_;
+  OptimizerCallerPtr new_env_get_item_, add_env_get_item_, env_get_set_item_, swap_env_get_item_depend_;
   std::vector<OptimizerCallerPtr> eliminaters_{};
 };
diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/grad_var_prepare.cc b/mindspore/ccsrc/frontend/optimizer/irpass/grad_var_prepare.cc
index 86cf60f84a..cec6763a51 100644
--- a/mindspore/ccsrc/frontend/optimizer/irpass/grad_var_prepare.cc
+++ b/mindspore/ccsrc/frontend/optimizer/irpass/grad_var_prepare.cc
@@ -29,10 +29,11 @@
 namespace mindspore {
 namespace opt {
 namespace irpass {
-static AnfNodePtr GenerateUnpackGraphNode(std::vector<AnfNodePtr> inputs_y, FuncGraphPtr func_graph,
+static AnfNodePtr GenerateUnpackGraphNode(const AnfNodePtr &origin_node, std::vector<AnfNodePtr> inputs_y,
                                           AnfNodePtr func_node, bool is_unpack, bool sens_param) {
-  MS_EXCEPTION_IF_NULL(func_graph);
   MS_EXCEPTION_IF_NULL(func_node);
+  FuncGraphPtr func_graph = origin_node->func_graph();
+  MS_EXCEPTION_IF_NULL(func_graph);
   std::vector<AnfNodePtr> nodes;
   AnfNodePtr unpack_graph_node = nullptr;
   if (is_unpack) {
@@ -42,7 +43,7 @@
     // {unpackcall, {GradOperation, ...}, args...}
     std::transform(inputs_y.begin() + 2, inputs_y.end(), std::back_inserter(nodes),
                    [](const AnfNodePtr &node) { return node; });
-    unpack_graph_node = func_graph->NewCNode(nodes);
+    unpack_graph_node = func_graph->NewCNodeBefore(origin_node, nodes);
   } else {
     auto unpack_graph = std::make_shared<prim::UnpackGraphPrimitive>("unpack_graph", sens_param, false);
     nodes.push_back(NewValueNode(unpack_graph));
@@ -50,7 +51,7 @@
     // {{GradOperation, ...}, args...}
     std::transform(inputs_y.begin() + 1, inputs_y.end(), std::back_inserter(nodes),
                    [](const AnfNodePtr &node) { return node; });
-    unpack_graph_node = func_graph->NewCNode(nodes);
+    unpack_graph_node = func_graph->NewCNodeBefore(origin_node, nodes);
   }
   return unpack_graph_node;
 }
@@ -87,12 +88,17 @@ bool IsMetaFuncGraph(const AnfNodePtr &node, const MetaFuncGraphPtr meta_func_gr
 // {{GradOperation, g, w}, Ys}
 // {UnPackCall, {GradOperation, g, w}, Ys}
 AnfNodePtr GradVarPrepare::operator()(const OptimizerPtr &, const AnfNodePtr &node) {
-  if (!node->isa<CNode>() || node->func_graph() == nullptr) {
+  auto cnode = node->cast<CNodePtr>();
+  if (cnode == nullptr) {
+    return nullptr;
+  }
+  auto func_graph = cnode->func_graph();
+  if (func_graph == nullptr) {
     return nullptr;
   }
 
   // {{...}, Ys}
-  auto inputs_y = node->cast<CNodePtr>()->inputs();
+  auto inputs_y = cnode->inputs();
   std::vector<AnfNodePtr> inputs_x;
   if (IsCNode(inputs_y[0])) {
     inputs_x = inputs_y[0]->cast<CNodePtr>()->inputs();
@@ -122,19 +128,17 @@ AnfNodePtr GradVarPrepare::operator()(const OptimizerPtr &, const AnfNodePtr &no
     return nullptr;
   }
 
-  AnfNodePtr unpack_graph_node =
-    GenerateUnpackGraphNode(inputs_y, node->cast<CNodePtr>()->func_graph(), func_node,
-                            IsMetaFuncGraph(inputs_y[0], unpack_op_), grad_op_ptr->sens_param());
-  // constuct new grad_opration
-  inputs_x[1] = unpack_graph_node;
-  auto grad_op_cnode = node->func_graph()->NewCNode(inputs_x);
+  const bool is_unpack = IsMetaFuncGraph(inputs_y[0], unpack_op_);
+  const bool sens_param = grad_op_ptr->sens_param();
+  inputs_x[1] = GenerateUnpackGraphNode(node, inputs_y, func_node, is_unpack, sens_param);
+  // construct new grad_operation
+  auto grad_op_cnode = func_graph->NewCNodeBefore(node, inputs_x);
   if (IsMetaFuncGraph(inputs_y[0],
unpack_op_)) { inputs_y[1] = grad_op_cnode; } else { inputs_y[0] = grad_op_cnode; } - auto cnode = node->func_graph()->NewCNode(inputs_y); - return cnode; + return func_graph->NewCNodeBefore(node, inputs_y); } } // namespace irpass } // namespace opt diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/merge_addn.h b/mindspore/ccsrc/frontend/optimizer/irpass/merge_addn.h index 38ca1748e4..2072c08ef3 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/merge_addn.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/merge_addn.h @@ -77,6 +77,14 @@ class MergeAddN : public AnfVisitor { is_match_ = false; return; } + + MonadState state_input = GetMonadState(inputs[1]); + MonadState state_cnode = GetMonadState(cnode, inputs[1]); + if (!IsStateEquivalent(state_cnode, state_input)) { + is_match_ = false; + return; + } + (void)Ys_.erase(Ys_.begin()); (void)std::copy(Xs_.begin(), Xs_.end(), std::back_inserter(args_)); (void)std::copy(Ys_.begin(), Ys_.end(), std::back_inserter(args_)); @@ -90,6 +98,14 @@ class MergeAddN : public AnfVisitor { is_match_ = false; return; } + + MonadState state_input = GetMonadState(inputs.back()); + MonadState state_cnode = GetMonadState(cnode, inputs.back()); + if (!IsStateEquivalent(state_cnode, state_input)) { + is_match_ = false; + return; + } + Ys_.pop_back(); (void)std::copy(Ys_.begin(), Ys_.end(), std::back_inserter(args_)); (void)std::copy(Xs_.begin(), Xs_.end(), std::back_inserter(args_)); diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/row_tensor_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/row_tensor_eliminate.h index abfc54327a..6235e71a22 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/row_tensor_eliminate.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/row_tensor_eliminate.h @@ -42,6 +42,16 @@ class RowTensorEliminater : public OptimizerCaller { return nullptr; } }; + +// {prim::kPrimRowTensorAdd, rowtensor, zeros_like(x)} -> rowtensor +class RowTensorAddZerosLike : public AnfVisitor { + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + PatternNode x, y; + auto zeros_like = PPrimitive(prim::kPrimZerosLike, y); + MATCH_REPLACE(node, PPrimitive(prim::kPrimRowTensorAdd, x, zeros_like), x); + return nullptr; + } +}; } // namespace irpass } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h index a4e84e9cbf..eed6b04565 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h @@ -362,7 +362,7 @@ class PynativeEliminater : public OptimizerCaller { CheckSymbolVNode(zeros_like_vnode.GetNode(node), "zeros_like"))) { auto rep = (arg).GetNode(node); if (rep != nullptr) { - if (rep->isa()) { + if (rep->isa() && !HasAbstractMonad(rep)) { auto value_node = rep->cast(); auto new_value_node = NewValueNode(FillZero(value_node->value())); new_value_node->set_has_new_value(value_node->has_new_value()); @@ -436,12 +436,12 @@ class AllReduceConstElim : public OptimizerCaller { AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { PatternNode x; auto pattern = PPrimitive(prim::kPrimAllReduce, x); - // If AllReduce takes contant value as input and values across devices are all the same(ensured by parallel mode) + // If AllReduce takes constant value as input and values across devices are all the same(ensured by parallel mode) if (pattern.TryCapture(node) && 
IsVNode(x.GetNode(node)) && (pattern.GetFuncGraph()->has_flag(parallel::AUTO_PARALLEL) || pattern.GetFuncGraph()->has_flag(parallel::SEMI_AUTO_PARALLEL))) { auto cur_func_graph = pattern.GetFuncGraph(); - // If reduce operation is sum, then multiply constant by number of devices, otherwise just return the contant + // If reduce operation is sum, then multiply constant by number of devices, otherwise just return the constant auto prim_cnode = pattern.GetOriginalNode(); MS_EXCEPTION_IF_NULL(prim_cnode); auto primitive = GetCNodePrimitive(prim_cnode); @@ -481,6 +481,37 @@ class AllReduceConstElim : public OptimizerCaller { return nullptr; } }; + +// This pattern introduced by Depend(CollectCNodeWithIsolateNodes) in program_specialize.cc +// {{prim::kPrimDepend, X, Y}, Xs}->{prim::kPrimDepend, {X, Xs}, Y} +class FloatDependGCall : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { + if (!node->isa() || node->func_graph() == nullptr) { + return nullptr; + } + + auto &inputs = node->cast()->inputs(); + // as IsCNodeDup had checked the size of inputs must be greater or equal than 1, so no check here. + if (IsPrimitiveCNode(inputs[0], prim::kPrimDepend)) { + auto &depend_inputs = inputs[0]->cast()->inputs(); + if (depend_inputs.size() != 3) { + return nullptr; + } + // put {Y, Xs} to new_inputs; + std::vector new_inputs({depend_inputs[1]}); + new_inputs.insert(new_inputs.end(), inputs.cbegin() + 1, inputs.cend()); + TraceGuard guard(std::make_shared(node->debug_info())); + ScopePtr scope = node->scope(); + ScopeGuard scope_guard(scope); + auto new_call_node = node->func_graph()->NewCNode(new_inputs); + auto new_node = node->func_graph()->NewCNode({depend_inputs[0], new_call_node, depend_inputs[2]}); + return new_node; + } + return nullptr; + } +}; + } // namespace irpass } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/stopgrad_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/stopgrad_eliminate.h new file mode 100644 index 0000000000..69e00d3de8 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/stopgrad_eliminate.h @@ -0,0 +1,55 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_STOPGRAD_ELIMINATE_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_STOPGRAD_ELIMINATE_H_ + +#include "ir/anf.h" +#include "base/core_ops.h" +#include "frontend/optimizer/optimizer.h" +#include "frontend/optimizer/anf_visitor.h" + +namespace mindspore::opt::irpass { +// +// StopGradientEliminater eliminates redundant stop_gradient nodes. +// +class StopGradientEliminater : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &start_node) override { + // We assume that the start_node is a StopGradient cnode. 
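+    // Editor's sketch: the loop below collapses a chain such as
+    //   stop_gradient(stop_gradient(stop_gradient(x)))
+    // down to the innermost node, i.e. a single stop_gradient(x).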
+ AnfNodePtr node = start_node; + AnfNodePtr input = nullptr; + while ((input = GetInputStopGradient(node)) != nullptr) { + node = input; + } + if (node != start_node) { + return node; + } + return nullptr; + } + + private: + static inline AnfNodePtr GetInputStopGradient(const AnfNodePtr &node) { + auto &input = node->cast()->inputs().at(1); + if (IsPrimitiveCNode(input, prim::kPrimStopGradient)) { + return input; + } + return nullptr; + } +}; +} // namespace mindspore::opt::irpass + +#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_STOPGRAD_ELIMINATE_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/updatestate_eliminate.cc b/mindspore/ccsrc/frontend/optimizer/irpass/updatestate_eliminate.cc new file mode 100644 index 0000000000..d2be02d8ae --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/updatestate_eliminate.cc @@ -0,0 +1,672 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "frontend/optimizer/irpass/updatestate_eliminate.h" + +#include +#include +#include +#include + +#include "frontend/operator/ops.h" + +namespace mindspore::opt::irpass { +namespace { +// data = Load(input, attach) +// data = Depend(input, attach) +// monad = UpdateState(input, attach) +constexpr size_t kInputIndex = 1; +constexpr size_t kAttachIndex = 2; +constexpr size_t kMakeTupleSize = 3; +constexpr size_t kMinDependSize = 3; +constexpr size_t kAssignSize = 4; +constexpr size_t kAssignMonadInputIndex = 3; + +FuncGraphManagerPtr GetManager(const AnfNodePtr &node) { + auto fg = node->func_graph(); + if (fg == nullptr) { + return nullptr; + } + return fg->manager(); +} + +// Return true if the node is only used by the given update_state node. +bool OnlyUpdateStateUse(const CNodePtr &update_state_node, const AnfNodePtr &node) { + auto mgr = GetManager(update_state_node); + if (mgr == nullptr) { + return false; + } + auto &node_users = mgr->node_users(); + auto iter = node_users.find(node); + if (iter == node_users.end()) { + return false; + } + auto &partial_users = iter->second; + return (partial_users.size() == 1) && (partial_users.front().first == update_state_node); +} + +// Eliminate useless node that only used by associated update_state. +// Convert: +// x1 = node(x, u) +// u1 = update_state(u, x1) # update_state is the only user of node +// user(u1) +// To: +// user(u) +AnfNodePtr EliminateUpdateStateOnlyUsedNode(const CNodePtr &update_state, const AnfNodePtr &node) { + if (!OnlyUpdateStateUse(update_state, node)) { + // Skip if UpdateState is not the only user of cnode. + return nullptr; + } + // Replace UpdateState with the input monad. + return update_state->inputs().at(kInputIndex); +} + +// Eliminate UpdateState that attaches a pure (no-side-effect) node. 
+// Convert: +// x = pure_node(args) # no side effect +// u1 = update_state(u, x) +// user(u1) +// To: +// x = pure_node(args) +// user(u) +AnfNodePtr EliminateUpdateStateForPureNode(const CNodePtr &update_state, const AnfNodePtr &attach) { + if (IsPrimitiveCNode(attach, prim::kPrimTupleGetItem)) { + // Skip tuple_getitem. + return nullptr; + } + auto cnode = dyn_cast(attach); + if (cnode == nullptr) { + // Skip value node or parameter. + return nullptr; + } + if (cnode->size() > 1) { + // If the last input is a monad, means the attach node has side-effect and + // we should keep UpdateState; otherwise, we will remove the UpdateState. + if (HasAbstractMonad(cnode->inputs().back())) { + return nullptr; + } + } + // Remove UpdateState by replace it with its input monad. + return update_state->inputs().at(kInputIndex); +} + +// Eliminate redundant UpdateState/Depend pair nodes caused by inline. +// Convert: +// x1 = Depend(x, u) +// u1 = UpdateState(u, x1) +// out = x_user(x1) +// u2 = u_user(u1) +// To: +// out = x_user(x) +// u2 = u_user(u) +AnfNodePtr EliminateUpdateStateWithDepend(const CNodePtr &update_state, const CNodePtr &depend) { + auto input_monad = depend->inputs().back(); + if (!HasAbstractMonad(input_monad)) { + // Skip if Depend attach input is not a monad. + return nullptr; + } + auto update_monad = update_state->inputs().at(kInputIndex); + if (!HasAbstractMonad(update_monad)) { + // Skip if UpdateState input is not a monad. + MS_LOG(WARNING) << "Not a monad input: " << update_state->DebugString(); + return nullptr; + } + // Check monad inputs. + const auto &input_monad_abs = *(input_monad->abstract()); + const auto &update_monad_abs = *(update_monad->abstract()); + bool same_monad = (input_monad_abs == update_monad_abs); + if (!same_monad) { + // Skip if they are different monad (one is IO, another is U). + return nullptr; + } + // Now we can eliminate the UpdateState and Depend nodes. + auto mgr = GetManager(update_state); + if (mgr == nullptr) { + return nullptr; + } + // Replace Depend with its input. + if (depend->size() == kMinDependSize) { + auto depend_input = depend->inputs().at(kInputIndex); + mgr->Replace(depend, depend_input); + } else { + auto inputs = depend->inputs(); + inputs.pop_back(); + auto fg = depend->func_graph(); + auto new_depend = fg->NewCNode(inputs); + new_depend->set_abstract(depend->abstract()); + mgr->Replace(depend, new_depend); + } + // Replace UpdateState node with the input monad of Depend. + return input_monad; +} + +// Eliminate useless make_tuple with 'Dead Node'. +// Convert: +// t = make_tuple(input, "Dead Node") +// u = UpdateState(u, t) +// To: +// u = UpdateState(u, input) +AnfNodePtr EliminateMakeTupleWithDeadNode(const CNodePtr &update_state, const CNodePtr &make_tuple) { + if (make_tuple->size() != kMakeTupleSize) { + return nullptr; + } + auto &node = make_tuple->inputs().at(kAttachIndex); + auto node_abs = node->abstract(); + if (node_abs == nullptr || !node_abs->isa()) { + return nullptr; + } + auto fg = update_state->func_graph(); + if (fg == nullptr) { + return nullptr; + } + // Create a new UpdateState to replace the old one. + const auto &attach = make_tuple->inputs().at(kInputIndex); + auto new_update_state = fg->NewCNode({update_state->input(0), update_state->input(1), attach}); + new_update_state->set_abstract(update_state->abstract()); + new_update_state->set_scope(update_state->scope()); + return new_update_state; +} + +// Return true if the function is only used by make_tuple. 
+bool OnlyMakeTupleUseFunc(const CNodePtr &make_tuple, const AnfNodePtr &func_node) {
+  auto mgr = GetManager(make_tuple);
+  if (mgr == nullptr) {
+    return false;
+  }
+  auto &node_users = mgr->node_users();
+  auto iter = node_users.find(func_node);
+  if (iter == node_users.end()) {
+    return false;
+  }
+  auto &partial_users = iter->second;
+  return (partial_users.size() == 1) && (partial_users.front().first == make_tuple);
+}
+
+// Eliminate an UpdateState whose second input is a MakeTuple, where one input of the MakeTuple
+// is a Function that is only used by this MakeTuple.
+// Convert:
+// t = make_tuple(input, Function) or t = make_tuple(Function, input)
+// u2 = UpdateState(u1, t)
+// To:
+// t = make_tuple(input, Function) or t = make_tuple(Function, input)
+// u2 = u1
+AnfNodePtr EliminateUpdateStateWithMakeTupleFunc(const CNodePtr &update_state, const CNodePtr &make_tuple) {
+  if (make_tuple->size() != kMakeTupleSize) {
+    return nullptr;
+  }
+  auto &first_input = make_tuple->inputs().at(kInputIndex);
+  if (IsValueNode<FuncGraph>(first_input) && OnlyMakeTupleUseFunc(make_tuple, first_input)) {
+    return update_state->input(1);
+  }
+  auto &second_input = make_tuple->inputs().at(kAttachIndex);
+  if (IsValueNode<FuncGraph>(second_input) && OnlyMakeTupleUseFunc(make_tuple, second_input)) {
+    return update_state->input(1);
+  }
+  return nullptr;
+}
+
+size_t GetLoadsFollowLoad(const CNodePtr &load, std::vector<CNodePtr> *loads);
+size_t GetLoadsFollowTuple(const CNodePtr &update_state, const CNodePtr &make_tuple, std::vector<CNodePtr> *loads);
+
+// Search consecutive Load nodes starting from an UpdateState node.
+size_t GetLoadsFromUpdateState(const CNodePtr &update_state, std::vector<CNodePtr> *loads) {
+  auto &attach = update_state->inputs().at(kAttachIndex);
+  if (IsPrimitiveCNode(attach, prim::kPrimLoad)) {
+    return GetLoadsFollowLoad(attach->cast<CNodePtr>(), loads);
+  }
+  if (IsPrimitiveCNode(attach, prim::kPrimMakeTuple)) {
+    return GetLoadsFollowTuple(update_state, attach->cast<CNodePtr>(), loads);
+  }
+  return 0;
+}
+
+size_t GetLoadsFollowLoad(const CNodePtr &load, std::vector<CNodePtr> *loads) {
+  loads->push_back(load);
+  auto &load_attach = load->inputs().at(kAttachIndex);
+  if (IsPrimitiveCNode(load_attach, prim::kPrimUpdateState)) {
+    return GetLoadsFromUpdateState(load_attach->cast<CNodePtr>(), loads) + 1;
+  }
+  return 1;
+}
+
+size_t GetLoadsFollowTuple(const CNodePtr &update_state, const CNodePtr &make_tuple, std::vector<CNodePtr> *loads) {
+  if (!OnlyUpdateStateUse(update_state, make_tuple)) {
+    // UpdateState should be the only user of the make_tuple.
+    return 0;
+  }
+  auto &inputs = make_tuple->inputs();
+  bool is_all_load = std::all_of(inputs.begin() + 1, inputs.end(),
+                                 [](const AnfNodePtr &node) { return IsPrimitiveCNode(node, prim::kPrimLoad); });
+  if (!is_all_load) {
+    // Stop if not all tuple elements are Load nodes.
+    return 0;
+  }
+  // Add load nodes from tuple elements.
+  for (size_t i = 1; i < inputs.size(); ++i) {
+    auto &element = inputs.at(i);
+    loads->push_back(element->cast<CNodePtr>());
+  }
+  // Follow the previous UpdateState if found.
+  auto prev_node = update_state->inputs().at(kInputIndex);
+  if (IsPrimitiveCNode(prev_node, prim::kPrimUpdateState)) {
+    return GetLoadsFromUpdateState(prev_node->cast<CNodePtr>(), loads) + 1;
+  }
+  return 1;
+}
+
+// Create a MakeTuple node before UpdateState if more than one node is attached.
+AnfNodePtr MakeTupleForSameNodes(const FuncGraphPtr &fg, const CNodePtr &old_update_state,
+                                 const AnfNodePtrList &make_tuple_inputs) {
+  constexpr size_t kOneNodeInputSize = 2;
+  if (make_tuple_inputs.size() == kOneNodeInputSize) {
+    // We don't need make_tuple since there is only one load.
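+    // Editor's note: make_tuple_inputs here is {MakeTuple, node}, so index 1 is
+    // the single attached node itself and no MakeTuple wrapper is needed.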
+ return make_tuple_inputs.at(1); + } + // Create MakeTuple cnode. + auto make_tuple = fg->NewCNode(make_tuple_inputs); + // Set abstract for the MakeTuple node. + abstract::AbstractBasePtrList element_abstracts; + std::transform(make_tuple_inputs.begin() + 1, make_tuple_inputs.end(), std::back_inserter(element_abstracts), + [](const AnfNodePtr &input) { return input->abstract(); }); + make_tuple->set_abstract(std::make_shared(element_abstracts)); + make_tuple->set_scope(old_update_state->scope()); + return make_tuple; +} + +// Eliminate UpdateStates for consecutive Loads. +// Convert: +// x1 = Load(x1, u) +// u1 = UpdateState(u, x1) +// x2 = Load(x2, u1) +// u2 = UpdateState(u1, x2) +// ... +// xN = Load(xN, u(N-1)) +// uN = UpdateState(u(N-1), xN) +// To: +// x1 = Load(x1, u) +// x2 = Load(x2, u) +// ... +// xN = Load(xN, u) +// t = make_tuple(x1, x2, ... , xN) +// u1 = UpdateState(u, t) +AnfNodePtr EliminateUpdateStateForLoads(const CNodePtr &old_update_state, const std::vector &loads) { + auto fg = old_update_state->func_graph(); + if (fg == nullptr) { + return nullptr; + } + auto mgr = fg->manager(); + if (mgr == nullptr) { + return nullptr; + } + // Prepare tuple elements from Load nodes. + AnfNodePtrList make_tuple_inputs; + std::set loaded_para_set; + make_tuple_inputs.reserve(loads.size() + 1); + make_tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple)); + auto input_monad = loads.back()->inputs().at(kAttachIndex); + for (auto iter = loads.rbegin(); iter != loads.rend(); ++iter) { + auto &load = *iter; + auto result = loaded_para_set.emplace(load->inputs().at(kInputIndex)); + const bool is_new_load = result.second; + if (is_new_load) { + // Put Load node as a tuple element, if the parameter is not loaded by other Load. + make_tuple_inputs.emplace_back(load); + } + if (load->inputs().at(kAttachIndex) != input_monad) { + // Set all load use same input monad. + mgr->SetEdge(load, kAttachIndex, input_monad); + } + } + if (make_tuple_inputs.size() == 1) { + // This should not happen. + MS_LOG(WARNING) << "No loads for " << old_update_state->DebugString(2); + return nullptr; + } + // Create the new UpdateState node with a MakeTuple, replace the old UpdateStateNode. + auto attach = MakeTupleForSameNodes(fg, old_update_state, make_tuple_inputs); + auto update_state = NewValueNode(prim::kPrimUpdateState); + auto new_update_state = fg->NewCNode({update_state, input_monad, attach}); + new_update_state->set_abstract(old_update_state->abstract()); + new_update_state->set_scope(old_update_state->scope()); + return new_update_state; +} + +// Eliminate UpdateStates between Assign nodes. 
+// Convert:
+// a1 = Assign(para1, value1, u1)
+// u2 = UpdateState(u1, a1)
+// a2 = Assign(para2, value2, u2)  # para1 != para2, para1 != value2, para2 != value1
+// u3 = UpdateState(u2, a2)
+// To:
+// a1 = Assign(para1, value1, u1)
+// a2 = Assign(para2, value2, u1)
+// t = MakeTuple(a1, a2)
+// u3 = UpdateState(u1, t)
+AnfNodePtr EliminateUpdateStateBetweenAssigns(const CNodePtr &update_state, const AnfNodePtr &assign) {
+  auto a2_cnode = assign->cast<CNodePtr>();
+  if (a2_cnode->size() != kAssignSize) {
+    return nullptr;
+  }
+  auto para2 = a2_cnode->input(kInputIndex);
+  auto value2 = a2_cnode->input(kAttachIndex);
+  auto u2 = a2_cnode->input(kAssignMonadInputIndex);
+  if (IsPrimitiveCNode(u2, prim::kPrimUpdateState)) {
+    auto a1 = u2->cast<CNodePtr>()->input(kAttachIndex);
+    if (IsPrimitiveCNode(a1, prim::kPrimAssign)) {
+      auto a1_cnode = a1->cast<CNodePtr>();
+      if (a1_cnode->size() != kAssignSize) {
+        return nullptr;
+      }
+      auto para1 = a1_cnode->input(kInputIndex);
+      auto value1 = a1_cnode->input(kAttachIndex);
+      auto u1 = a1_cnode->input(kAssignMonadInputIndex);
+      if (para1 != para2 && para1 != value2 && para2 != value1) {
+        auto fg = update_state->func_graph();
+        MS_EXCEPTION_IF_NULL(fg);
+        auto mgr = fg->manager();
+        mgr->Replace(u2, u1);
+        AnfNodePtrList make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple), a1, assign};
+        auto make_tuple = MakeTupleForSameNodes(fg, update_state, make_tuple_inputs);
+        auto new_update_state = fg->NewCNode({NewValueNode(prim::kPrimUpdateState), u1, make_tuple});
+        new_update_state->set_abstract(update_state->abstract());
+        new_update_state->set_scope(update_state->scope());
+        return new_update_state;
+      }
+    }
+  }
+  return nullptr;
+}
+
+// Eliminate UpdateStates between MakeTuple and Assign.
+// Convert:
+// a1 = Assign(para1, value1, u1)
+// a2 = Assign(para2, value2, u2)  # u2 == u1
+// t = MakeTuple(a1, a2)
+// u3 = UpdateState(u1, t)
+// a3 = Assign(para3, value3, u3)  # para3 != para1, para3 != para2, value3 != para1, value3 != para2
+//                                 # value1 != para3, value2 != para3
+// u4 = UpdateState(u3, a3)
+// To:
+// a1 = Assign(para1, value1, u1)
+// a2 = Assign(para2, value2, u1)
+// a3 = Assign(para3, value3, u1)
+// t = MakeTuple(a1, a2, a3)
+// u4 = UpdateState(u1, t)
+AnfNodePtr EliminateUpdateStateBetweenMakeTupleAssign(const CNodePtr &update_state, const AnfNodePtr &assign) {
+  auto a3_cnode = assign->cast<CNodePtr>();
+  if (a3_cnode->size() != kAssignSize) {
+    return nullptr;
+  }
+  auto para3 = a3_cnode->input(kInputIndex);
+  auto value3 = a3_cnode->input(kAttachIndex);
+  auto u3 = a3_cnode->input(kAssignMonadInputIndex);
+  if (IsPrimitiveCNode(u3, prim::kPrimUpdateState)) {
+    auto make_tuple = u3->cast<CNodePtr>()->input(kAttachIndex);
+    if (IsPrimitiveCNode(make_tuple, prim::kPrimMakeTuple)) {
+      auto make_tuple_cnode = make_tuple->cast<CNodePtr>();
+      if (make_tuple_cnode->size() != kMakeTupleSize) {
+        return nullptr;
+      }
+      auto a1 = make_tuple_cnode->input(kInputIndex);
+      auto a2 = make_tuple_cnode->input(kAttachIndex);
+      if (IsPrimitiveCNode(a1, prim::kPrimAssign) && IsPrimitiveCNode(a2, prim::kPrimAssign)) {
+        auto a1_cnode = a1->cast<CNodePtr>();
+        auto a2_cnode = a2->cast<CNodePtr>();
+        if (a1_cnode->size() != kAssignSize || a2_cnode->size() != kAssignSize) {
+          return nullptr;
+        }
+        auto para1 = a1_cnode->input(kInputIndex);
+        auto value1 = a1_cnode->input(kAttachIndex);
+        auto u1 = a1_cnode->input(kAssignMonadInputIndex);
+        auto para2 = a2_cnode->input(kInputIndex);
+        auto value2 = a2_cnode->input(kAttachIndex);
+        auto u2 = a2_cnode->input(kAssignMonadInputIndex);
+        bool replace_judge = (u1 == u2) && (para1 != para3) && (para1 != value3) && (para2 != para3) &&
+                             (para2 != value3) && (value1 != para3) && (value2 != para3);
+        if (replace_judge) {
+          auto fg = update_state->func_graph();
+          MS_EXCEPTION_IF_NULL(fg);
+          auto mgr = fg->manager();
+          MS_EXCEPTION_IF_NULL(mgr);
+          mgr->Replace(u3, u1);
+          AnfNodePtrList new_make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple), make_tuple_cnode->input(kInputIndex),
+                                               make_tuple_cnode->input(kAttachIndex), assign};
+          auto new_make_tuple = MakeTupleForSameNodes(fg, update_state, new_make_tuple_inputs);
+          mgr->Replace(make_tuple, new_make_tuple);
+          auto new_update_state = fg->NewCNode({NewValueNode(prim::kPrimUpdateState), u1, new_make_tuple});
+          new_update_state->set_abstract(update_state->abstract());
+          new_update_state->set_scope(update_state->scope());
+          return new_update_state;
+        }
+      }
+    }
+  }
+  return nullptr;
+}
+
+// Eliminate UpdateStates between Assign and MakeTuple.
+// Convert:
+// a1 = Assign(para1, value1, u1)
+// u2 = UpdateState(u1_1, a1)  # u1_1 == u1
+// a2 = Assign(para2, value2, u2)
+// a3 = Assign(para3, value3, u3)  # u2 == u3
+// t = MakeTuple(a2, a3)
+// u4 = UpdateState(u3, t)  # para1 != para2, para1 != para3, value2 != para1, value3 != para1
+//                          # value1 != para2, value1 != para3
+// To:
+// a1 = Assign(para1, value1, u1)
+// a2 = Assign(para2, value2, u1)
+// a3 = Assign(para3, value3, u1)
+// t = MakeTuple(a1, a2, a3)
+// u4 = UpdateState(u1, t)
+AnfNodePtr EliminateUpdateStateBetweenAssignMakeTuple(const CNodePtr &update_state, const AnfNodePtr &make_tuple) {
+  auto make_tuple_cnode = make_tuple->cast<CNodePtr>();
+  if (make_tuple_cnode->size() != kMakeTupleSize) {
+    return nullptr;
+  }
+  auto a2 = make_tuple_cnode->input(kInputIndex);
+  auto a3 = make_tuple_cnode->input(kAttachIndex);
+  if (IsPrimitiveCNode(a2, prim::kPrimAssign) && IsPrimitiveCNode(a3, prim::kPrimAssign)) {
+    auto a2_cnode = a2->cast<CNodePtr>();
+    auto a3_cnode = a3->cast<CNodePtr>();
+    if (a2_cnode->size() != kAssignSize || a3_cnode->size() != kAssignSize) {
+      return nullptr;
+    }
+    auto para2 = a2_cnode->input(kInputIndex);
+    auto value2 = a2_cnode->input(kAttachIndex);
+    auto u2 = a2_cnode->input(kAssignMonadInputIndex);
+    if (!IsPrimitiveCNode(u2, prim::kPrimUpdateState)) {
+      return nullptr;
+    }
+    auto para3 = a3_cnode->input(kInputIndex);
+    auto value3 = a3_cnode->input(kAttachIndex);
+    auto u3 = a3_cnode->input(kAssignMonadInputIndex);
+    if (u2 == u3) {
+      auto u2_cnode = u2->cast<CNodePtr>();
+      MS_EXCEPTION_IF_NULL(u2_cnode);
+      auto u1 = u2_cnode->input(kInputIndex);
+      auto a1 = u2_cnode->input(kAttachIndex);
+      if (IsPrimitiveCNode(a1, prim::kPrimAssign)) {
+        auto a1_cnode = a1->cast<CNodePtr>();
+        MS_EXCEPTION_IF_NULL(a1_cnode);
+        if (a1_cnode->size() != kAssignSize) {
+          return nullptr;
+        }
+        auto para1 = a1_cnode->input(kInputIndex);
+        auto value1 = a1_cnode->input(kAttachIndex);
+        auto u1_1 = a1_cnode->input(kAssignMonadInputIndex);
+        bool replace_judge = (u1 == u1_1) && (para1 != para2) && (para1 != para3) && (para1 != value2) &&
+                             (para1 != value3) && (para2 != value1) && (para3 != value1);
+        if (replace_judge) {
+          auto fg = update_state->func_graph();
+          MS_EXCEPTION_IF_NULL(fg);
+          auto mgr = fg->manager();
+          mgr->Replace(u2, u1);
+          AnfNodePtrList new_make_tuple_inputs{NewValueNode(prim::kPrimMakeTuple), a1,
+                                               make_tuple_cnode->input(kInputIndex),
+                                               make_tuple_cnode->input(kAttachIndex)};
+          auto new_make_tuple = MakeTupleForSameNodes(fg, update_state, new_make_tuple_inputs);
+          mgr->Replace(make_tuple, new_make_tuple);
+          auto new_update_state = fg->NewCNode({NewValueNode(prim::kPrimUpdateState), u1,
new_make_tuple}); + new_update_state->set_abstract(update_state->abstract()); + new_update_state->set_scope(update_state->scope()); + return new_update_state; + } + } + } + } + return nullptr; +} + +} // namespace + +AnfNodePtr UpdatestateEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) { + auto update_state_node = dyn_cast(node); + if (update_state_node == nullptr || update_state_node->inputs().empty()) { + MS_LOG(WARNING) << "UpdatestateEliminater encounter invalid node: " << node->DebugString(); + return nullptr; + } + auto &attach = update_state_node->inputs().at(kAttachIndex); + if (IsPrimitiveCNode(attach, prim::kPrimDepend)) { + return EliminateUpdateStateWithDepend(update_state_node, attach->cast()); + } + if (IsPrimitiveCNode(attach, prim::kPrimPartial)) { + return EliminateUpdateStateOnlyUsedNode(update_state_node, attach); + } + const bool attach_is_load = IsPrimitiveCNode(attach, prim::kPrimLoad); + if (attach_is_load) { + auto new_node = EliminateUpdateStateOnlyUsedNode(update_state_node, attach); + if (new_node != nullptr) { + return new_node; + } + // We should continue check when useless Load not found, + // since GetLoadsFromUpdateState() also need to check Load. + } + + const bool attach_is_assign = IsPrimitiveCNode(attach, prim::kPrimAssign); + if (attach_is_assign) { + auto new_node = EliminateUpdateStateBetweenAssigns(update_state_node, attach); + if (new_node != nullptr) { + return new_node; + } + new_node = EliminateUpdateStateBetweenMakeTupleAssign(update_state_node, attach); + if (new_node != nullptr) { + return new_node; + } + } + + const bool attach_is_tuple = IsPrimitiveCNode(attach, prim::kPrimMakeTuple); + if (attach_is_tuple) { + auto new_node = EliminateMakeTupleWithDeadNode(update_state_node, attach->cast()); + if (new_node != nullptr) { + return new_node; + } + // We should continue check when MakeTuple with "Dead Node" not found, + // since GetLoadsFromUpdateState() also need to check MakeTuple. + + new_node = EliminateUpdateStateWithMakeTupleFunc(update_state_node, attach->cast()); + if (new_node != nullptr) { + return new_node; + } + + new_node = EliminateUpdateStateBetweenAssignMakeTuple(update_state_node, attach->cast()); + if (new_node != nullptr) { + return new_node; + } + } + std::vector loads; + if (GetLoadsFromUpdateState(update_state_node, &loads) > 1 && loads.size() > 1) { + return EliminateUpdateStateForLoads(update_state_node, loads); + } + // Eliminate UpdateStates that attaches a no-side-effect node. + if (!attach_is_load && !attach_is_tuple) { + return EliminateUpdateStateForPureNode(update_state_node, attach); + } + return nullptr; +} + +// Eliminate Monad parameter for switch call. +// Convert: +// x = Load(x, u) +// u = UpdateState(u, x) +// ... +// g1 = Partial(...) +// g2 = Partial(...) +// s = switch(cond, g1, g2) +// res = s(u) +// To: +// x = Load(x, u) +// u = UpdateState(u, x) +// ... 
+// g1 = Partial(..., u) +// g2 = Partial(..., u) +// s = switch(cond, g1, g2) +// res = s() +AnfNodePtr EliminateMonadParameterForSwitchCall(const AnfNodePtr &node) { + const CNodePtr &switch_call = dyn_cast(node); + if (switch_call == nullptr) { + return nullptr; + } + auto fg = switch_call->func_graph(); + if (fg == nullptr) { + return nullptr; + } + auto mgr = fg->manager(); + if (mgr == nullptr) { + return nullptr; + } + if (switch_call->inputs().size() < 2) { + return nullptr; + } + constexpr size_t primary_index = 0; + auto switch_node = switch_call->input(primary_index); + if (!IsPrimitiveCNode(switch_node, prim::kPrimSwitch)) { + return nullptr; + } + MS_LOG(DEBUG) << "Found switch call with monad parameter, " << switch_call->DebugString(); + auto switch_cnode = dyn_cast(switch_node); + if (switch_cnode == nullptr) { + MS_LOG(EXCEPTION) << "switch node cast to CNode failed, " << switch_node->DebugString(); + } + constexpr size_t condition_index = 1; + constexpr size_t first_fg_index = 2; + constexpr size_t second_fg_index = 3; + auto fg1_node = switch_cnode->input(first_fg_index); + auto fg2_node = switch_cnode->input(second_fg_index); + auto build_partial = [&fg, &switch_call](const AnfNodePtr &node) { + CNodePtr new_node; + if (IsPrimitiveCNode(node, prim::kPrimPartial)) { // Node is already Partial CNode. + new_node = fg->NewCNode(node->cast()->inputs()); + } else { // Node is FuncGraph ValueNode. + new_node = fg->NewCNode({NewValueNode(prim::kPrimPartial), node}); + } + constexpr size_t args_start_index = 1; + for (size_t i = args_start_index; i < switch_call->inputs().size(); i++) { + new_node->add_input(switch_call->input(i)); + } + return new_node; + }; + fg1_node = build_partial(fg1_node); + fg2_node = build_partial(fg2_node); + auto cond = switch_cnode->input(condition_index); + auto new_switch_cnode = fg->NewCNode({NewValueNode(prim::kPrimSwitch), cond, fg1_node, fg2_node}); + auto new_switch_call = fg->NewCNode({new_switch_cnode}); + return new_switch_call; +} + +AnfNodePtr SwitchCallMonadParameterEliminater::operator()(const OptimizerPtr &, const AnfNodePtr &node) { + return EliminateMonadParameterForSwitchCall(node); +} +} // namespace mindspore::opt::irpass diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/updatestate_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/updatestate_eliminate.h new file mode 100644 index 0000000000..1e61459cc7 --- /dev/null +++ b/mindspore/ccsrc/frontend/optimizer/irpass/updatestate_eliminate.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_UPDATESTATE_ELIMINATE_H_ +#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_UPDATESTATE_ELIMINATE_H_ + +#include "frontend/optimizer/optimizer.h" +#include "frontend/optimizer/anf_visitor.h" + +namespace mindspore::opt::irpass { +// +// UpdatestateEliminater eliminates redundant UpdateState related nodes. 
+// +class UpdatestateEliminater : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; +}; + +// +// SwitchCallMonadParameterEliminater eliminates Monad parameter in switch call. +// +class SwitchCallMonadParameterEliminater : public AnfVisitor { + public: + AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override; +}; +} // namespace mindspore::opt::irpass + +#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_UPDATESTATE_ELIMINATE_H_ diff --git a/mindspore/ccsrc/frontend/optimizer/opt.cc b/mindspore/ccsrc/frontend/optimizer/opt.cc index b87260c3f3..5a4c56cc20 100644 --- a/mindspore/ccsrc/frontend/optimizer/opt.cc +++ b/mindspore/ccsrc/frontend/optimizer/opt.cc @@ -141,6 +141,8 @@ bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNo #ifdef ENABLE_PROFILE double t = GetTime(); #endif + MS_LOG(DEBUG) << "transform: " << transform->name_ << " will replace: " << node->DebugString() + << " with: " << ret->DebugString(); (void)manager->Replace(node, ret); #ifdef ENABLE_PROFILE MsProfile::StatTime("replace." + transform->name_, GetTime() - t); @@ -205,6 +207,14 @@ bool SubstitutionList::operator()(const FuncGraphPtr &func_graph, const Optimize changes = changes || change; loop = loop || change; + // apply transform on isolate nodes. + auto &isolate_nodes = manager->isolate_nodes(); + for (auto &node : isolate_nodes) { + change = ApplyTransform(optimizer, node, list_[i]); + changes = changes || change; + loop = loop || change; + } + // record the status of each transform static const auto enable_dump_pass_ir = (common::GetEnv("ENV_DUMP_PASS_IR") == "1"); if (enable_dump_pass_ir && MsContext::GetInstance()->get_param(MS_CTX_SAVE_GRAPHS_FLAG)) { diff --git a/mindspore/ccsrc/frontend/parallel/graph_util/node_info.cc b/mindspore/ccsrc/frontend/parallel/graph_util/node_info.cc index e3d30fbb7d..a3d04c56e9 100644 --- a/mindspore/ccsrc/frontend/parallel/graph_util/node_info.cc +++ b/mindspore/ccsrc/frontend/parallel/graph_util/node_info.cc @@ -48,6 +48,13 @@ bool ParameterRequireGrad(const AnfNodePtr &node_ptr) { return param_value->requires_grad(); } +AnfNodePtr GetRealInput(const AnfNodePtr &input) { + if (IsPrimitiveCNode(input, prim::kPrimLoad)) { + return input->cast()->input(1); + } + return input; +} + // Given the node, return whether each input is a parameter or a output of a operator. 
// The returned boolean vector should be the same order of the inputs, thus its implementation // is closely consistent with ExtractShape() in step_parallel.cc @@ -70,12 +77,13 @@ std::vector ExtractInputParameterByNode(const CNodePtr &node) { node_inputs = node_inputs[1]->cast()->inputs(); } for (size_t i = 1; i < node_inputs.size(); ++i) { - auto input = node_inputs[i]; + auto input = GetRealInput(node_inputs[i]); if (input->isa()) { auto input_parameter = input->cast(); is_parameter.push_back(ParameterRequireGrad(input_parameter)); - } else if (input->isa() || IsValueNode(input) || IsValueNode(input)) { + } else if ((input->isa() && !HasAbstractMonad(input)) || IsValueNode(input) || + IsValueNode(input)) { is_parameter.push_back(false); } } @@ -174,7 +182,8 @@ std::vector ExtractInputTypeLengthByNode(const CNodePtr &node) { MS_LOG(EXCEPTION) << "Find parameter by ref key node failed"; } inputs_type_len.push_back(GetInputsTypeLen(parameters[0])); - } else if (input->isa() || input->isa() || IsValueNode(input)) { + } else if ((input->isa() && !HasAbstractMonad(input)) || input->isa() || + IsValueNode(input)) { // extract input shape from parameter and apply node inputs_type_len.push_back(GetInputsTypeLen(input)); } diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/dropout_do_mask_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/dropout_do_mask_info.cc index 6d0d84178a..b26d3a497e 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/dropout_do_mask_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/dropout_do_mask_info.cc @@ -177,6 +177,16 @@ Status DropoutDoMaskInfo::InitForCostModel(const StrategyPtr &strategy) { return SUCCESS; } +size_t GetNonMonadInputSize(const CNodePtr &cnode) { + size_t cnode_non_monad_size = cnode->size(); + for (auto &input : cnode->inputs()) { + if (HasAbstractMonad(input)) { + cnode_non_monad_size--; + } + } + return cnode_non_monad_size; +} + PrimitivePtr GetDropoutGenMaskPrim(const CNodePtr &cnode) { MS_EXCEPTION_IF_NULL(cnode); if (cnode->size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) { @@ -190,7 +200,8 @@ PrimitivePtr GetDropoutGenMaskPrim(const CNodePtr &cnode) { } auto dropout_gen_mask_cnode = dropout_gen_mask->cast(); - if (dropout_gen_mask_cnode->size() != DROPOUT_GEN_MASK_CNODE_INPUT_SIZE) { + size_t cnode_non_monad_size = GetNonMonadInputSize(dropout_gen_mask_cnode); + if (cnode_non_monad_size != DROPOUT_GEN_MASK_CNODE_INPUT_SIZE) { MS_LOG(EXCEPTION) << "The size of dropout gen mask cnode's inputs must be " << DROPOUT_GEN_MASK_CNODE_INPUT_SIZE; } if (!IsValueNode(dropout_gen_mask_cnode->input(0))) { @@ -220,7 +231,8 @@ void SetGenMaskShape(const CNodePtr &cnode, const Shape &input_slice_shape) { } auto dropout_gen_mask_cnode = dropout_gen_mask->cast(); - if (dropout_gen_mask_cnode->size() != DROPOUT_GEN_MASK_CNODE_INPUT_SIZE) { + size_t cnode_non_monad_size = GetNonMonadInputSize(dropout_gen_mask_cnode); + if (cnode_non_monad_size != DROPOUT_GEN_MASK_CNODE_INPUT_SIZE) { MS_LOG(EXCEPTION) << "The size of dropout gen mask cnode's inputs must be " << DROPOUT_GEN_MASK_CNODE_INPUT_SIZE; } diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h index ff889e6eb3..8f6b4e80e0 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h @@ -374,6 +374,8 @@ constexpr char EMBED[] = "embed"; constexpr char CREATINSTANCE[] = "create_instance"; constexpr char REF_TO_EMBED[] = "RefToEmbed"; constexpr char 
STOP_GRADIENT[] = "stop_gradient"; +constexpr char UPDATESTATE[] = "UpdateState"; +constexpr char LOAD[] = "Load"; // Batch parallel black list constexpr char TENSOR_SCATTER_UPDATE[] = "TensorScatterUpdate"; diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.cc index 7e46af83ba..e56ab8b082 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.cc @@ -209,8 +209,7 @@ ForwardOp CreateReduceMeanForwardOp(const std::vector &forward_group, con OperatorName operator1_name = REAL_DIV; std::vector device_list = forward_group[0].GetDevicesList(); auto divisor = static_cast(device_list.size()); - std::vector tensor_data = {divisor}; - mindspore::tensor::TensorPtr tensor_ptr = std::make_shared(tensor_data, dtype); + mindspore::tensor::TensorPtr tensor_ptr = std::make_shared(divisor, dtype); ValuePtr op1_param_value = MakeValue(tensor_ptr); Attr op1_param = std::make_pair("divisor", op1_param_value); OperatorParams operator1_params = {std::make_pair(op1_param, 2)}; diff --git a/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc b/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc index 7e490c63fa..6c4c4f58bd 100644 --- a/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc +++ b/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.cc @@ -200,16 +200,37 @@ std::pair PipelineTransformer::GetOpInfo(const A return std::make_pair(op_info, std::make_shared(tensor_info)); } -std::pair PipelineTransformer::GetParameterPair(const AnfNodePtr &node) { +CNodePtr PipelineTransformer::HandleMonadLoad(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); - auto node_users = manager_->node_users()[node]; + auto &node_users = manager_->node_users()[node]; for (auto &user_pair : node_users) { auto user_node = user_pair.first->cast(); MS_EXCEPTION_IF_NULL(user_node); - if (!IsPipelineCareNode(user_node)) { - continue; + if (IsPipelineCareNode(user_node)) { + return user_node; + } + } + return nullptr; +} + +std::pair PipelineTransformer::GetParameterPair(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + auto &node_users = manager_->node_users()[node]; + for (auto &user_pair : node_users) { + auto care_node = user_pair.first; + auto care_cnode = care_node->cast(); + if (IsPrimitiveCNode(care_node, prim::kPrimLoad)) { + care_cnode = HandleMonadLoad(care_node); + if (!care_cnode) { + continue; + } + } else { + if (!IsPipelineCareNode(care_cnode)) { + continue; + } } - auto op_info = CreateOpInfo(user_node); + MS_EXCEPTION_IF_NULL(care_cnode); + auto op_info = CreateOpInfo(care_cnode); MS_EXCEPTION_IF_NULL(op_info); auto tensor_info = op_info->inputs_tensor_info()[IntToSize(user_pair.second) - 1]; return std::make_pair(nullptr, std::make_shared(tensor_info)); @@ -334,13 +355,22 @@ static std::pair GetShapeType(const AnfNodePtr &node, con return std::make_pair(shape_list, dtype); } +AnfNodePtr PipelineTransformer::HandleMonadDepend(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + if (IsPrimitiveCNode(node, prim::kPrimDepend)) { + auto cnode = node->cast(); + return HandleMonadDepend(cnode->input(1)); + } + return node; +} + AnfNodePtr PipelineTransformer::FindPipelineCareNode(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); if (IsValueNode(cnode->input(0))) { auto 
graph = GetValueNode(cnode->input(0)); - auto output = graph->output(); + auto output = HandleMonadDepend(graph->output()); MS_EXCEPTION_IF_NULL(output); if (output->isa()) { return output; diff --git a/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.h b/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.h index 694c049b11..1d441a67ca 100644 --- a/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.h +++ b/mindspore/ccsrc/frontend/parallel/pipeline_transformer/pipeline_transformer.h @@ -72,6 +72,8 @@ class PipelineTransformer { const std::vector &out_input); AnfNodePtr FindPipelineCareNode(const AnfNodePtr &node); std::pair GetOpInfo(const AnfNodePtr &node); + AnfNodePtr HandleMonadDepend(const AnfNodePtr &node); + CNodePtr HandleMonadLoad(const AnfNodePtr &node); std::pair GetParameterPair(const AnfNodePtr &node); OperatorInfoPtr CreateOpInfo(const CNodePtr &cnode); bool IsPipelineCareNode(const CNodePtr &cnode); diff --git a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc index a50ca6e463..c6ab5f1984 100644 --- a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc @@ -819,6 +819,9 @@ void ReshapeCostCompute(const std::vector &all_nodes) { MS_ASSERT(cnode->inputs().size() == 3); // get previous node's strategy_cost_ auto pre_node = cnode->input(1); + if (IsPrimitiveCNode(pre_node, prim::kPrimLoad)) { + pre_node = pre_node->cast()->input(1); + } int64_t out_index = 0; OperatorInfoPtr pre_operator_info; std::vector> pre_stra_costs; diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc index d4dd5352e1..d5bb19c3f1 100644 --- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc @@ -55,7 +55,7 @@ using mindspore::tensor::Tensor; namespace mindspore { namespace parallel { static const std::set COMMUNICATION_OPS = {ALL_REDUCE, ALL_GATHER, ALL_TO_ALL, REDUCE_SCATTER}; -static const std::set INVALID_LOSS_OPS = {GET_NEXT, VIRTUALLOSS}; +static const std::set INVALID_LOSS_OPS = {GET_NEXT, VIRTUALLOSS, LOAD, UPDATESTATE}; // g_RefMap, for CNode B input i is a RefKey[Parameter C], // it will be one item in map with key: C, and value: (B, i) static std::map> g_RefMap; @@ -619,7 +619,7 @@ void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_ MS_EXCEPTION_IF_NULL(prim_anf_node); PrimitivePtr node_prim = prim_anf_node->value()->cast(); MS_EXCEPTION_IF_NULL(node_prim); - if (node_prim->name() == DEPEND && node_pair.second != 1) { + if ((node_prim->name() == DEPEND && node_pair.second != 1) || node_prim->name() == UPDATESTATE) { continue; } if (IsParallelCareNode(use_cnode) && use_cnode->has_user_data()) { @@ -803,6 +803,9 @@ void ReplaceOneOp(const Operator &replace_op, const CNodePtr &node) { std::string instance_name = CreateInstanceName(node, 0); std::vector replace_input; replace_input = ReplaceOpInput(replace_op, instance_name, node); + if (node->inputs().size() == DROPOUT_DO_MASK_CNODE_INPUT_SIZE) { + replace_input.push_back(node->input(3)); + } CNodePtr replace_node = func_graph->NewCNode(replace_input); MS_EXCEPTION_IF_NULL(replace_node); ScopePtr scope = node->scope(); @@ -1000,7 +1003,7 @@ std::pair FindParameter(const AnfNodePtr &node, const FuncGrap for (size_t index = 0; index < cnode->inputs().size(); ++index) { PrimitivePtr prim = 
prim_anf_node->value()->cast(); MS_EXCEPTION_IF_NULL(prim); - if (prim->name() == DEPEND && index != 1) { + if ((prim->name() == DEPEND || prim->name() == LOAD) && index != 1) { continue; } if (!FindParameter(cnode->input(index), func_graph).first) { @@ -1104,7 +1107,11 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons MS_EXCEPTION_IF_NULL(func_graph); FuncGraphManagerPtr manager = func_graph->manager(); MS_EXCEPTION_IF_NULL(manager); - + for (auto input : node->inputs()) { + if (input->isa() && HasAbstractMonad(input)) { + node_size--; + } + } if ((node->inputs().size() == 2) && (IsValueNode(node->input(1)))) { MS_LOG(INFO) << "Input is ValueList, skip it."; return; @@ -1417,7 +1424,8 @@ std::vector ExtractShape(const CNodePtr &node) { std::pair node_pair = std::make_pair(node, SizeToLong(i)); g_RefMap[parameters[0]] = node_pair; input_shapes = GetRefKeyNodeShape(input, func_graph); - } else if (IsValueNode(input) || input->isa() || input->isa() || + } else if ((input->isa() && !HasAbstractMonad(input)) || IsValueNode(input) || + input->isa() || ((IsValueNode(input) || IsValueNode(input)) && (inputs_size == 2))) { input_shapes = GetNodeShape(input); } else { @@ -2017,6 +2025,13 @@ std::shared_ptr FindParameterNextLayout(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(manager); AnfNodeIndexSet node_set = manager->node_users()[node]; for (auto &node_pair : node_set) { + if (IsPrimitiveCNode(node_pair.first, prim::kPrimLoad)) { + auto layout_param = FindParameterNextLayout(node_pair.first); + if (!layout_param) { + continue; + } + return layout_param; + } CNodePtr use_apply = node_pair.first->cast(); if (use_apply == nullptr || !IsValueNode(use_apply->input(0))) { continue; @@ -2109,7 +2124,8 @@ std::shared_ptr FindPrevLayout(const AnfNodePtr &node) { if (!IsValueNode(cnode->input(0))) { return nullptr; } - if (IsParallelCareNode(cnode) && cnode->has_user_data()) { + if (IsParallelCareNode(cnode) && cnode->has_user_data() && + !IsPrimitiveCNode(node, prim::kPrimReshape)) { auto layout_ptr = GetOutputLayoutFromCNode(cnode, 0); if (!layout_ptr) { MS_LOG(EXCEPTION) << "Failure:GetLayoutFromCNode failed"; @@ -3027,13 +3043,26 @@ ParameterUsersInfo FindParameterNodeUsers(const AnfNodePtr &node, bool (*IsCareN auto candidate_set = node->func_graph()->manager()->node_users()[node]; for (auto &candidate : candidate_set) { auto candidate_node = candidate.first; - auto c = candidate_node->cast(); - if (c == nullptr || !c->has_user_data() || IsSomePrimitive(c, RECEIVE)) { - continue; + if (IsPrimitiveCNode(candidate_node, prim::kPrimLoad)) { + if (candidate.second != 1) { + continue; + } + auto load_node_users = node->func_graph()->manager()->node_users()[candidate_node]; + for (auto &node_user : load_node_users) { + auto cnode = node_user.first->cast(); + if (cnode == nullptr || !cnode->has_user_data() || IsSomePrimitive(cnode, RECEIVE)) { + continue; + } + (void)parameter_user_info.second.second.insert(node_user); + } + } else { + auto c = candidate_node->cast(); + if (c == nullptr || !c->has_user_data() || IsSomePrimitive(c, RECEIVE)) { + continue; + } + (void)parameter_user_info.second.second.insert(candidate); } - (void)parameter_user_info.second.second.insert(candidate); } - parameter_user_info.first = node->cast()->name(); parameter_user_info.second.first = node; return parameter_user_info; diff --git a/mindspore/ccsrc/pipeline/jit/action.cc b/mindspore/ccsrc/pipeline/jit/action.cc index 88ce65b3cb..ce8e692d6b 100644 --- a/mindspore/ccsrc/pipeline/jit/action.cc 
+++ b/mindspore/ccsrc/pipeline/jit/action.cc @@ -31,6 +31,7 @@ #include "pipeline/jit/pass.h" #include "pipeline/jit/parse/parse_base.h" #include "pipeline/jit/parse/data_converter.h" +#include "pipeline/jit/static_analysis/auto_monad.h" #include "abstract/abstract_value.h" #include "pipeline/jit/static_analysis/static_analysis.h" #include "pipeline/jit/static_analysis/program_specialize.h" @@ -106,7 +107,9 @@ FuncGraphPtr Renormalize(const ResourcePtr &res, const FuncGraphPtr &func_graph, MsProfile::StatTime("renormalize.infer", t2 - t1); MsProfile::StatTime("renormalize.specialize", t3 - t2); #endif + MS_LOG(DEBUG) << "Renormalize end"; + return ret; } @@ -167,11 +170,11 @@ bool CombineLikeGraphs(const ResourcePtr &res) { auto base_graph = cloner->cloned_func_graph()[fg]; MS_LOG(DEBUG) << "Basegraph:" << base_graph->ToString(); - if (fg->paramter_obj_nodes().size() == 0 || graphs.size() <= 1) { + if (fg->used_global_parameters().empty() || graphs.size() <= 1) { continue; } auto &cloned_nodes = *cloner->cloned_node(); - for (auto &fv : fg->paramter_obj_nodes()) { + for (auto &fv : fg->used_global_parameters()) { TraceGuard guard(std::make_shared(fv->debug_info())); auto param = base_graph->add_parameter(); auto &node_users = res->manager()->node_users()[fv]; @@ -185,10 +188,10 @@ bool CombineLikeGraphs(const ResourcePtr &res) { repl_n->set_input(n.second, param); } } - MS_LOG(DEBUG) << "Fg0 paramter_obj_nodes size :" << fg->paramter_obj_nodes().size(); + MS_LOG(DEBUG) << "Fg0 used_global_parameters size :" << fg->used_global_parameters().size(); for (auto &g : graphs) { - auto fvs = g->paramter_obj_nodes(); + auto &fvs = g->used_global_parameters(); std::vector new_node_inputs; new_node_inputs.push_back(NewValueNode(base_graph)); for (auto &p : g->parameters()) { @@ -196,7 +199,7 @@ bool CombineLikeGraphs(const ResourcePtr &res) { new_node_inputs.push_back(para_after_cast); } (void)new_node_inputs.insert(new_node_inputs.end(), fvs.begin(), fvs.end()); - AnfNodePtr out = g->NewCNode(new_node_inputs); + AnfNodePtr out = g->NewCNodeBefore(g->get_return(), new_node_inputs); g->set_output(out); MS_LOG(DEBUG) << "Combine graph newout:" << out->DebugString(4); } @@ -209,21 +212,33 @@ bool SymbolResolveAction(const ResourcePtr &res) { if (res->manager() == nullptr) { MS_LOG(EXCEPTION) << "SymbolResolve error, manager is null"; } - if (res->func_graph() == nullptr) { + auto func_graph = res->func_graph(); + if (func_graph == nullptr) { MS_LOG(EXCEPTION) << "SymbolResolve error, graph is null"; } - FuncGraphPtr func_graph = res->func_graph(); - auto succ = parse::ResolveFuncGraph(func_graph, res); - + bool ret = parse::ResolveFuncGraph(func_graph, res); // Remove unused nodes in cnode order list. 
- func_graph->EraseUnusedNodeInOrder(); - func_graph->ReleaseFullOrderToEffectOrder(); - for (auto fg : func_graph->func_graphs_used_total()) { - MS_EXCEPTION_IF_NULL(fg); - fg->EraseUnusedNodeInOrder(); - fg->ReleaseFullOrderToEffectOrder(); - } - return succ; + if (func_graph) { + func_graph->EraseUnusedNodeInOrder(); + for (auto fg : func_graph->func_graphs_used_total()) { + if (fg) { + fg->EraseUnusedNodeInOrder(); + } + } + } + return ret; +} + +bool AutoMonadAction(const ResourcePtr &res) { + if (res->manager() == nullptr) { + MS_LOG(EXCEPTION) << "Auto-Monad failed, manager is null"; + } + auto func_graph = res->func_graph(); + if (func_graph == nullptr) { + MS_LOG(EXCEPTION) << "Auto-Monad failed, graph is null"; + } + (void)pipeline::AutoMonad(func_graph); + return true; } bool InferenceOptPrepareAction(const ResourcePtr &res) { @@ -270,6 +285,16 @@ bool AbstractSpecializeAction(const ResourcePtr &res) { FuncGraphPtr new_fg = ProgramSpecialize(res, result.context->func_graph(), result.context); res->set_func_graph(new_fg); + // Remove unused nodes in cnode order list, this is prepared for auto-monad. + if (new_fg) { + new_fg->EraseUnusedNodeInOrder(); + for (auto fg : new_fg->func_graphs_used_total()) { + if (fg) { + fg->EraseUnusedNodeInOrder(); + } + } + } + MS_LOG(DEBUG) << "End graph: " << new_fg->ToString() << ", return: " << new_fg->get_return()->DebugString(true); return true; } @@ -447,7 +472,7 @@ bool StartPSSchedulerAction(const ResourcePtr &res) { #endif // The parallel primitive related valuenode might be partitioned so that its value changes by device, -// that will result in a syncronization error due to different executing order. +// that will result in a synchronization error due to different executing order. // Here we temporarily avoid the problem by skipping valuenode merging used by parallel related primitive, // the final solution will be proposed later as a parallel feature. bool KeepValueNodeDuplication(const AnfNodePtr &value_node, const ResourcePtr &res) { @@ -558,6 +583,7 @@ static std::vector CommonPipeline() { // Resolve the python func actions.emplace_back(std::make_pair("symbol_resolve", SymbolResolveAction)); + auto multi_graphs = parallel::CostModelContext::GetInstance()->is_multi_subgraphs(); if (!multi_graphs) { actions.emplace_back(std::make_pair("combine_like_graphs", CombineLikeGraphs)); @@ -566,6 +592,8 @@ static std::vector CommonPipeline() { actions.emplace_back(std::make_pair("inference_opt_prepare", InferenceOptPrepareAction)); // Evaluate type and shape, and specialize actions.emplace_back(std::make_pair("abstract_specialize", AbstractSpecializeAction)); + // Auto-monad for side-effects handling. 
+ actions.emplace_back(std::make_pair("auto_monad", AutoMonadAction)); // Do data structure simplifications and inline actions.emplace_back(std::make_pair("inline", OptInlineAction)); // Add pre-ad, post-inline python pass stub diff --git a/mindspore/ccsrc/pipeline/jit/action.h b/mindspore/ccsrc/pipeline/jit/action.h index 231d62b23b..e00abae37a 100644 --- a/mindspore/ccsrc/pipeline/jit/action.h +++ b/mindspore/ccsrc/pipeline/jit/action.h @@ -32,6 +32,7 @@ using ActionItem = std::pair>; bool ParseAction(const ResourcePtr &res); bool SymbolResolveAction(const ResourcePtr &res); +bool AutoMonadAction(const ResourcePtr &res); bool AbstractSpecializeAction(const ResourcePtr &res); bool GeOptimizeAction(const ResourcePtr &res); bool VmOptimizeAction(const ResourcePtr &res); diff --git a/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc b/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc index ee9abb319a..14b219d3d0 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc @@ -415,6 +415,10 @@ bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature converted = obj.cast(); } else if (py::isinstance(obj)) { converted = obj.cast(); + } else if (py::isinstance(obj)) { + converted = obj.cast(); + } else if (py::isinstance(obj)) { + converted = obj.cast(); } else if (py::isinstance(obj)) { auto env = obj.cast>(); converted = env; diff --git a/mindspore/ccsrc/pipeline/jit/parse/function_block.cc b/mindspore/ccsrc/pipeline/jit/parse/function_block.cc index 9225ccc8e1..e23c7e97b1 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/function_block.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/function_block.cc @@ -37,18 +37,54 @@ FunctionBlock::FunctionBlock(const Parser &parser) : parser_(parser) { void FunctionBlock::AddPrevBlock(const FunctionBlockPtr &block) { prev_blocks_.push_back(block.get()); } +static bool CanBeIsolateNode(const std::string &var_name, const AnfNodePtr &node) { + auto cnode = dyn_cast(node); + if (cnode == nullptr || cnode->inputs().empty()) { + // Not a valid cnode, can not be isolate node. + return false; + } + auto prim = GetValueNode(cnode->inputs().at(0)); + if (prim == nullptr) { + // Not a primitive cnode, it may have side effects or not, + // we add it as an isolate node if its name is not '_' or empty. + // this means that code like: + // _ = func_call() + // will be ignored even if func_call() has side effects. + return !var_name.empty() && var_name != "_"; + } + // For primitive cnode, only those with side effects can be isolate nodes. + auto effect_info = GetPrimEffectInfo(prim); + bool has_effects = (effect_info.memory || effect_info.io); + return has_effects; +} + // write variable records the variable name to corresponding node void FunctionBlock::WriteVariable(const std::string &var_name, const AnfNodePtr &node) { MS_LOG(DEBUG) << func_graph_->ToString() << " write var " << var_name << " with node " << node->DebugString(); - vars_[var_name] = node; + auto [iter, is_new_name] = vars_.emplace(var_name, std::make_pair(node, false)); + if (!is_new_name) { + // If a cnode variable with same name already existed but not used, + // add it as an isolate node. for example: + // a = print(x) + // a = print(y) + // when we write variable 'a = print(y)', + // the cnode 'print(x)' should added as an isolate node. 
+ if (!iter->second.second && CanBeIsolateNode(var_name, iter->second.first)) { + func_graph_->AddIsolateNode(iter->second.first); + } + iter->second = std::make_pair(node, false); + } } // read variable from predecessors AnfNodePtr FunctionBlock::ReadVariable(const std::string &var) { // get var node if it is found - if (vars_.count(var)) { - AnfNodePtr node = vars_[var]; + auto found = vars_.find(var); + if (found != vars_.end()) { + auto &node = found->second.first; MS_EXCEPTION_IF_NULL(node); + // Mark the variable as used. + found->second.second = true; auto iter = resolve_to_removable_phis_.find(node); if (iter != resolve_to_removable_phis_.end()) { return iter->second; @@ -63,7 +99,7 @@ AnfNodePtr FunctionBlock::ReadVariable(const std::string &var) { MS_EXCEPTION_IF_NULL(block); return block->ReadVariable(var); } else if (prev_blocks_.empty()) { - // get namespace and make Reslove + // get namespace and make Resolve auto it = var_to_resolve_.find(var); if (it != var_to_resolve_.end()) { return it->second; @@ -141,7 +177,7 @@ AnfNodePtr FunctionBlock::MakeResolve(const NameSpacePtr &name_space, const Symb << ((std::string)resolve_symbol->symbol()); ValueNodePtr module_node = NewValueNode(name_space); ValueNodePtr symbol_node = NewValueNode(resolve_symbol); - auto node = func_graph()->NewCNode({NewValueNode(prim::kPrimResolve), module_node, symbol_node}); + auto node = func_graph()->NewCNodeInOrder({NewValueNode(prim::kPrimResolve), module_node, symbol_node}); return node; } @@ -264,13 +300,13 @@ void FunctionBlock::Mature() { // Force the condition node to bool using bool operation CNodePtr FunctionBlock::ForceToBoolNode(const AnfNodePtr &cond) { TraceGuard trace_guard(std::make_shared(cond->debug_info())); - CNodePtr op_apply_node = func_graph()->NewCNode({MakeResolveOperation(NAMED_PRIMITIVE_BOOL), cond}); + CNodePtr op_apply_node = func_graph()->NewCNodeInOrder({MakeResolveOperation(NAMED_PRIMITIVE_BOOL), cond}); return op_apply_node; } CNodePtr FunctionBlock::ForceToWhileCond(const AnfNodePtr &cond) { TraceGuard trace_guard(std::make_shared(cond->debug_info())); - CNodePtr op_apply_node = func_graph()->NewCNode({MakeResolveOperation("while_cond"), cond}); + CNodePtr op_apply_node = func_graph()->NewCNodeInOrder({MakeResolveOperation("while_cond"), cond}); return op_apply_node; } @@ -286,11 +322,10 @@ void FunctionBlock::Jump(const FunctionBlockPtr &target_block, AnfNodePtr node) input_nodes.emplace_back(node); } - CNodePtr jump = func_graph()->NewCNode(input_nodes); + CNodePtr jump = func_graph()->NewCNodeInOrder(input_nodes); jumps_[target_block.get()] = jump; target_block->AddPrevBlock(shared_from_this()); func_graph()->set_output(jump); - InsertDependItemsBeforeReturn(); } // Perform a conditional jump using switch operation.
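The WriteVariable/ReadVariable changes above are the heart of the new bookkeeping: every variable entry now pairs the bound node with a used flag, and a binding that is overwritten before ever being read keeps its old node alive as an isolate node when that node may carry side effects. A minimal self-contained sketch of that flow, using hypothetical Node/Block stand-ins rather than the real AnfNode/FunctionBlock types (slightly simplified: the real pass additionally skips nodes still reachable through another, used variable):

#include <iostream>
#include <map>
#include <set>
#include <string>
#include <utility>

struct Node {
  std::string repr;
  bool has_side_effects;
};

class Block {
 public:
  void WriteVariable(const std::string &name, Node *node) {
    auto [iter, is_new_name] = vars_.emplace(name, std::make_pair(node, false));
    if (!is_new_name) {
      // Overwriting a binding that was never read: keep the old node alive
      // as an isolate node if it may have side effects.
      if (!iter->second.second && CanBeIsolateNode(name, iter->second.first)) {
        isolate_nodes_.insert(iter->second.first);
      }
      iter->second = std::make_pair(node, false);
    }
  }

  Node *ReadVariable(const std::string &name) {
    auto found = vars_.find(name);
    if (found == vars_.end()) {
      return nullptr;
    }
    found->second.second = true;  // mark the variable as used
    return found->second.first;
  }

  // Called once parsing is done: any never-read, effectful binding
  // left in the map also becomes an isolate node.
  void FindIsolateVariables() {
    for (auto &var : vars_) {
      if (!var.second.second && CanBeIsolateNode(var.first, var.second.first)) {
        isolate_nodes_.insert(var.second.first);
      }
    }
  }

  const std::set<Node *> &isolate_nodes() const { return isolate_nodes_; }

 private:
  static bool CanBeIsolateNode(const std::string &name, Node *node) {
    // '_' (or an empty name) is a deliberate discard; anything else
    // whose node has side effects must be kept.
    return node != nullptr && node->has_side_effects && !name.empty() && name != "_";
  }

  std::map<std::string, std::pair<Node *, bool>> vars_;
  std::set<Node *> isolate_nodes_;
};

int main() {
  Node print_x{"print(x)", true}, print_y{"print(y)", true}, add{"x + y", false};
  Block block;
  block.WriteVariable("a", &print_x);
  block.WriteVariable("a", &print_y);  // 'print(x)' becomes an isolate node here
  block.WriteVariable("b", &add);      // unused but effect-free: simply dropped
  block.FindIsolateVariables();        // 'print(y)' is collected here
  for (Node *n : block.isolate_nodes()) {
    std::cout << n->repr << "\n";
  }
  return 0;
}

With a C++17 compiler this prints exactly the two effectful nodes, print(x) and print(y), while the effect-free 'x + y' is silently dropped, matching the behavior described in the comments above.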
@@ -302,68 +337,56 @@ void FunctionBlock::ConditionalJump(AnfNodePtr condNode, const FunctionBlockPtr << trace::GetDebugInfo(func_graph()->get_return()->debug_info()); } CNodePtr switch_app = - func_graph()->NewCNode({NewValueNode(prim::kPrimSwitch), condNode, NewValueNode(true_block->func_graph()), - NewValueNode(false_block->func_graph())}); - CNodePtr switch_app_new = func_graph()->NewCNode({switch_app}); + func_graph()->NewCNodeInOrder({NewValueNode(prim::kPrimSwitch), condNode, NewValueNode(true_block->func_graph()), + NewValueNode(false_block->func_graph())}); + CNodePtr switch_app_new = func_graph()->NewCNodeInOrder({switch_app}); func_graph()->set_output(switch_app_new); - InsertDependItemsBeforeReturn(); } -void FunctionBlock::SetStateAssgin(const AnfNodePtr &target, const std::string &readid) { +// Create a cnode for the assign statement like 'self.target = source'. +// Convert it to 'P.Assign(self.target, source)' and then add the cnode as an isolate node. +void FunctionBlock::SetStateAssign(const AnfNodePtr &target, const AnfNodePtr &source) { const std::string primitive_name("assign"); const std::string module_name("mindspore.ops.functional"); ValueNodePtr assign_op = NewValueNode(prim::GetPythonOps(primitive_name, module_name, true)); - auto source = ReadVariable(readid); - auto assign = func_graph()->NewCNode({assign_op, target, source}); - WriteVariable(readid, assign); - MS_LOG(INFO) << "SetState read " << target->DebugString() << ", " << readid; - AddAutoDepend(assign); + auto assign = func_graph_->NewCNodeInOrder({assign_op, target, source}); + func_graph_->AddIsolateNode(assign); } -void FunctionBlock::AddAutoDepend(const AnfNodePtr &target) { auto_depends_.push_back(target); } - -void FunctionBlock::InsertDependItemsBeforeReturn() { - if (!prev_blocks_.empty()) { - for (auto &prev_block : prev_blocks_) { - MS_LOG(DEBUG) << "Has prev_block " << prev_block->func_graph()->debug_info().get(); +void FunctionBlock::FindIsolateVariables() { + // + // Search for isolate nodes among the variables; for example, + // variable 'a' is an isolate node in the code below: + // + // def construct(self, x, y): + // a = print(x) # isolate node + // return x + y + // + std::set used; + // Find used variables. + for (const auto &var : vars_) { + auto &node = var.second.first; + if (node == nullptr) { + continue; + } + bool is_used = var.second.second; + if (is_used) { + used.emplace(node); } } - - ValueNodePtr make_tuple_op = NewValueNode(prim::kPrimMakeTuple); - ValueNodePtr depend_op = NewValueNode(prim::kPrimDepend); - ValueNodePtr stop_gradient_op = NewValueNode(prim::kPrimStopGradient); - - if (auto_depends_.size() == 0) { - return; - } - AnfNodePtr state = nullptr; - std::vector vec_states; - vec_states.emplace_back(make_tuple_op); - for (auto &item : auto_depends_) { - MS_LOG(DEBUG) << "auto_depends " << item->ToString(); - vec_states.emplace_back(item); - } - // if there are only make_tuple_op and another node in vec_states(the vec_states size is 2) - // do not need to make_tuple, just use the node. - if (vec_states.size() == 2) { - state = vec_states[1]; - } else { - state = func_graph()->NewCNode(vec_states); - } - - AnfNodePtr old_ret = nullptr; - auto return_node = func_graph()->get_return(); - if (return_node) { - if (return_node->inputs().size() < 1) { - MS_LOG(EXCEPTION) << "Length of inputs of output node is less than 2"; + // Add unused vars that are not found in the used set as isolate nodes.
+ for (const auto &var : vars_) { + auto &node = var.second.first; + bool is_used = var.second.second; + if (node == nullptr || is_used) { + continue; + } + auto &var_name = var.first; + if (used.find(node) == used.end() && CanBeIsolateNode(var_name, node)) { + func_graph_->AddIsolateNode(node); } - old_ret = return_node->input(1); - } else { - old_ret = NewValueNode(kNone); } - AnfNodePtr stopped = func_graph()->NewCNode({stop_gradient_op, state}); - AnfNodePtr ret = func_graph()->NewCNode({depend_op, old_ret, stopped}); - func_graph()->set_output(ret, true); } + } // namespace parse } // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/jit/parse/function_block.h b/mindspore/ccsrc/pipeline/jit/parse/function_block.h index d7efba824b..f5c1dce6e4 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/function_block.h +++ b/mindspore/ccsrc/pipeline/jit/parse/function_block.h @@ -61,10 +61,8 @@ class FunctionBlock : public std::enable_shared_from_this { AnfNodePtr SearchReplaceNode(const std::string &var, const ParameterPtr &phi); void ConditionalJump(AnfNodePtr condNode, const FunctionBlockPtr &trueBlock, const FunctionBlockPtr &falseBlock, bool unroll_loop = true); - // record the assign statement of self.xx weight parameter ,which will use state_setitem op - void SetStateAssgin(const AnfNodePtr &target, const std::string &readid); - void AddAutoDepend(const AnfNodePtr &target); - void InsertDependItemsBeforeReturn(); + // Create cnode for the assign statement like self.target = source. + void SetStateAssign(const AnfNodePtr &target, const AnfNodePtr &source); void AddGlobalVar(const std::string &var_name) { (void)global_vars_.insert(var_name); } bool IsGlobalVar(const std::string &var_name) { return global_vars_.find(var_name) != global_vars_.end(); } AnfNodePtr MakeResolveAstOp(const py::object &op); @@ -73,6 +71,7 @@ class FunctionBlock : public std::enable_shared_from_this { AnfNodePtr MakeResolveOperation(const std::string &value); AnfNodePtr MakeResolve(const std::shared_ptr &name_space, const std::shared_ptr &resolve_symbol); const std::unordered_map &removable_phis() const { return removable_phis_; } + void FindIsolateVariables(); private: // block graph @@ -88,8 +87,8 @@ class FunctionBlock : public std::enable_shared_from_this { // refer to comments in Parser::func_block_list_; std::vector prev_blocks_; - // store args and variable's node - std::map vars_; + // store args and variable's node, use a bool flag to indicate if the variable is used. + std::map> vars_; // phi_nodes map the parameter node to variable, it can be resolved if the block's predecessors are processed std::map phi_nodes_; @@ -110,10 +109,6 @@ class FunctionBlock : public std::enable_shared_from_this { // hold declared global variables in function std::set global_vars_; - // other depend need to insert before function return nodes. - // summary or some other node - std::vector auto_depends_; - // keeps the new made resolve symbol for the variable not found in vars_. 
std::unordered_map var_to_resolve_; }; diff --git a/mindspore/ccsrc/pipeline/jit/parse/parse.cc b/mindspore/ccsrc/pipeline/jit/parse/parse.cc index 0814e0bcef..b98c296325 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/parse.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/parse.cc @@ -81,7 +81,7 @@ AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNo return param; } auto cast_helper = prim::kPrimMixedPrecisionCast; - auto cast = func_graph->NewCNode({NewValueNode(cast_helper), NewValueNode(dst_type), param}); + auto cast = func_graph->NewCNodeAfter(param, {NewValueNode(cast_helper), NewValueNode(dst_type), param}); return cast; } @@ -185,7 +185,7 @@ void UpdataParam(const FuncGraphPtr &top_graph, const py::object &cell) { } void CheckFuncReturn(const FuncGraphPtr &fn, const std::shared_ptr &ast) { - // check whether the functions refered by this function and itself are missing 'return' statement + // check whether the functions referred by this function and itself are missing 'return' statement auto mng = Manage(fn, false); for (auto func_graph : mng->func_graphs()) { if (func_graph->get_return() != nullptr) { @@ -212,6 +212,11 @@ FuncGraphPtr Parser::ParseFuncGraph() { return nullptr; } + // Add unused variables as isolate nodes. + for (auto &block : func_block_list_) { + block->FindIsolateVariables(); + } + RemoveUnnecessaryPhis(); MS_EXCEPTION_IF_NULL(pFnBlock); @@ -342,15 +347,14 @@ FunctionBlockPtr Parser::ParseFunction(const py::object &node, const FunctionBlo } FunctionBlockPtr Parser::ParseStatements(FunctionBlockPtr fn_block, const py::object &nodes) { - py::int_ pcount = python_adapter::CallPyObjMethod(nodes, "__len__"); - size_t count = LongToSize(pcount); + auto node_list = py::cast(nodes); + size_t count = py::len(node_list); MS_LOG(DEBUG) << "The nodes count is " << count; - for (size_t i = 0; i < count; i++) { - auto node = py::cast(nodes)[i]; + for (size_t i = 0; i < count; ++i) { + auto node = node_list[i]; fn_block = ParseStatement(fn_block, node); // insert appropriate depended items for the function block if it has a return node if (fn_block->func_graph()->get_return() != nullptr) { - fn_block->InsertDependItemsBeforeReturn(); // Skip statements after 'return' (or 'break', 'continue'). break; } @@ -406,7 +410,6 @@ AnfNodePtr Parser::ParseExprNode(const FunctionBlockPtr &block, const py::object } // process the expr statement and expand it -// eg: x.append(y) -> x = x.append(y) FunctionBlockPtr Parser::ParseExpr(const FunctionBlockPtr &block, const py::object &node) { MS_LOG(DEBUG) << "Process ast Expr"; // Expr only has a value, no target AnfNodePtr value_node = ParseExprNode(block, value_object); if (py::len(expand_info) == 2) { + // An expression that is not assigned to any variable + // is usually a call with side effects, + // e.g.: print(x) + // we save it as an isolate node.
+ value_node->func_graph()->AddIsolateNode(value_node); } else { - // expand the assign statement + // expand the assign statement, + // e.g.: x.append(y) -> x = x.append(y) py::object target_node = expand_info[2]; WriteAssignVars(block, target_node, value_node); } @@ -465,7 +472,7 @@ FunctionBlockPtr Parser::ParseReturn(const FunctionBlockPtr &block, const py::ob py::object value = python_adapter::GetPyObjAttr(node, "value"); AnfNodePtr pReturnStatementNode = ParseExprNode(block, value); // Create the cnode - CNodePtr pReturnCNode = block->func_graph()->NewCNode({pReturnValueNode, pReturnStatementNode}); + CNodePtr pReturnCNode = block->func_graph()->NewCNodeInOrder({pReturnValueNode, pReturnStatementNode}); block->func_graph()->set_return(pReturnCNode); @@ -493,7 +500,7 @@ AnfNodePtr Parser::ParseBinOp(const FunctionBlockPtr &block, const py::object &n // resolve the op AnfNodePtr op_node = block->MakeResolveAstOp(op); // create apply node - return block->func_graph()->NewCNode({op_node, left_node, right_node}); + return block->func_graph()->NewCNodeInOrder({op_node, left_node, right_node}); } AnfNodePtr Parser::ParseName(const FunctionBlockPtr &block, const py::object &node) { @@ -592,7 +599,7 @@ AnfNodePtr Parser::GenerateMakeTuple(const FunctionBlockPtr &block, const std::v make_tuple_nodes.push_back(make_tuple_op); (void)std::transform(element_nodes.begin(), element_nodes.end(), std::back_inserter(make_tuple_nodes), [](AnfNodePtr arg) -> AnfNodePtr { return arg; }); - return block->func_graph()->NewCNode(make_tuple_nodes); + return block->func_graph()->NewCNodeInOrder(make_tuple_nodes); } AnfNodePtr Parser::ParseSuper(const FunctionBlockPtr &block, const py::list &args) { @@ -652,7 +659,7 @@ CNodePtr MakeUnpackCall(const FuncGraphPtr &func_graph, const AnfNodePtr &call_f unpack_call_nodes.push_back(call_function_anf_node); (void)std::transform(packed_arguments.begin(), packed_arguments.end(), std::back_inserter(unpack_call_nodes), [](AnfNodePtr node) -> AnfNodePtr { return node; }); - CNodePtr unpack_call = func_graph->NewCNode(unpack_call_nodes); + CNodePtr unpack_call = func_graph->NewCNodeInOrder(unpack_call_nodes); return unpack_call; } @@ -668,7 +675,7 @@ AnfNodePtr Parser::GenerateAnfNodeForCall(const FunctionBlockPtr &block, const A func_call_nodes.push_back(call_function_anf_node); (void)std::transform(group_arguments.begin(), group_arguments.end(), std::back_inserter(func_call_nodes), [](AnfNodePtr node) -> AnfNodePtr { return node; }); - CNodePtr call_anf_node = block->func_graph()->NewCNode(func_call_nodes); + CNodePtr call_anf_node = block->func_graph()->NewCNodeInOrder(func_call_nodes); return call_anf_node; } @@ -720,7 +727,7 @@ bool Parser::ParseKeywordsInCall(const FunctionBlockPtr &block, const py::object make_dict_nodes.push_back(make_dict_op); make_dict_nodes.push_back(keys_tuple); make_dict_nodes.push_back(values_tuple); - packed_arguments->push_back(block->func_graph()->NewCNode(make_dict_nodes)); + packed_arguments->push_back(block->func_graph()->NewCNodeInOrder(make_dict_nodes)); } return need_unpack; } @@ -770,7 +777,7 @@ AnfNodePtr Parser::ParseAttribute(const FunctionBlockPtr &block, const py::objec } // create the apply node - return block->func_graph()->NewCNode({op_node, value_node, attr_node}); + return block->func_graph()->NewCNodeInOrder({op_node, value_node, attr_node}); } // Process comparison expression : a == b. a > b etc. 
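The parser call sites in this file switch from NewCNode to NewCNodeInOrder so that nodes created while translating a statement are also appended to the graph's ordered node list; the auto-monad pass added later in this patch consumes that list to recover source-statement order when sequencing side effects. A rough sketch of the distinction, using a hypothetical stub Graph rather than the real FuncGraph API:

#include <iostream>
#include <list>
#include <memory>
#include <string>
#include <utility>
#include <vector>

struct CNode {
  std::string op;
};
using CNodePtr = std::shared_ptr<CNode>;

class Graph {
 public:
  // Plain creation: the node exists in the graph but carries no
  // position in the statement order.
  CNodePtr NewCNode(std::string op) {
    auto node = std::make_shared<CNode>(CNode{std::move(op)});
    nodes_.push_back(node);
    return node;
  }

  // Creation that additionally appends the node to the ordered node
  // list, preserving the order in which source statements were parsed.
  CNodePtr NewCNodeInOrder(std::string op) {
    auto node = NewCNode(std::move(op));
    order_list_.push_back(node);
    return node;
  }

  const std::list<CNodePtr> &order_list() const { return order_list_; }

 private:
  std::vector<CNodePtr> nodes_;
  std::list<CNodePtr> order_list_;
};

int main() {
  Graph g;
  g.NewCNodeInOrder("print(x)");           // statement 1
  g.NewCNodeInOrder("assign(self.w, y)");  // statement 2
  g.NewCNode("internal-helper");           // bookkeeping node, kept out of the order
  for (const auto &node : g.order_list()) {
    std::cout << node->op << "\n";
  }
  return 0;
}

This also explains the EraseUnusedNodeInOrder calls earlier in the patch: the order list must be pruned of dead nodes before auto-monad walks it.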
@@ -793,7 +800,7 @@ AnfNodePtr Parser::ParseCompare(const FunctionBlockPtr &block, const py::object MS_EXCEPTION_IF_NULL(block); AnfNodePtr op_node = block->MakeResolveAstOp(ops[0]); - return block->func_graph()->NewCNode({op_node, left_node, right_node}); + return block->func_graph()->NewCNodeInOrder({op_node, left_node, right_node}); } AnfNodePtr Parser::ProcessBoolOpValueList(const FunctionBlockPtr &block, const py::list &value_list, AstSubType mode) { @@ -839,12 +846,12 @@ AnfNodePtr Parser::ProcessBoolOpValueList(const FunctionBlockPtr &block, const p b2->func_graph()->set_output(test_node); auto cond_node = block->ForceToBoolNode(test_node); - auto switch_app = - block->func_graph()->NewCNode({NewValueNode(prim::kPrimSwitch), cond_node, NewValueNode(true_block->func_graph()), - NewValueNode(false_block->func_graph())}); + auto switch_app = block->func_graph()->NewCNodeInOrder({NewValueNode(prim::kPrimSwitch), cond_node, + NewValueNode(true_block->func_graph()), + NewValueNode(false_block->func_graph())}); std::vector call_graph_nodes{switch_app}; - auto switch_app_call = block->func_graph()->NewCNode(call_graph_nodes); + auto switch_app_call = block->func_graph()->NewCNodeInOrder(call_graph_nodes); return switch_app_call; } } @@ -855,7 +862,7 @@ AnfNodePtr Parser::ParseBoolOp(const FunctionBlockPtr &block, const py::object & py::object op_node = python_adapter::GetPyObjAttr(node, "op"); AstSubType op_type = ast_->GetOpType(op_node); if (op_type == AST_SUB_TYPE_UNKNOWN) { - MS_LOG(WARNING) << "ProcessBoolOp, got unkown op type"; + MS_LOG(WARNING) << "ProcessBoolOp, got unknown op type"; return nullptr; } py::list op_values = python_adapter::GetPyObjAttr(node, "values"); @@ -919,7 +926,7 @@ AnfNodePtr Parser::ParseTuple(const FunctionBlockPtr &block, const py::object &n AnfNodePtr node_ptr = ParseExprNode(block, elts[i]); tuple_vec.emplace_back(node_ptr); } - CNodePtr tuple_app = block->func_graph()->NewCNode(tuple_vec); + CNodePtr tuple_app = block->func_graph()->NewCNodeInOrder(tuple_vec); return tuple_app; } @@ -940,7 +947,7 @@ AnfNodePtr Parser::ParseList(const FunctionBlockPtr &block, const py::object &no AnfNodePtr node_ptr = ParseExprNode(block, elts[i]); list_vec.emplace_back(node_ptr); } - CNodePtr list_app = block->func_graph()->NewCNode(list_vec); + CNodePtr list_app = block->func_graph()->NewCNodeInOrder(list_vec); return list_app; } @@ -954,7 +961,7 @@ AnfNodePtr Parser::ParseSubscript(const FunctionBlockPtr &block, const py::objec AnfNodePtr value = ParseExprNode(block, value_node); AnfNodePtr slice = ParseExprNode(block, slice_node); - return block->func_graph()->NewCNode({op_getitem, value, slice}); + return block->func_graph()->NewCNodeInOrder({op_getitem, value, slice}); } // process a slice, get the slice value @@ -969,7 +976,7 @@ AnfNodePtr Parser::ParseSlice(const FunctionBlockPtr &block, const py::object &n AnfNodePtr stop_node = ParseExprNode(block, stop); AnfNodePtr step_node = ParseExprNode(block, step); - return block->func_graph()->NewCNode({op_makeslice, start_node, stop_node, step_node}); + return block->func_graph()->NewCNodeInOrder({op_makeslice, start_node, stop_node, step_node}); } // process a extslice @@ -985,7 +992,7 @@ AnfNodePtr Parser::ParseExtSlice(const FunctionBlockPtr &block, const py::object AnfNodePtr node_ptr = ParseExprNode(block, slice_tuple[i]); node_vec.emplace_back(node_ptr); } - CNodePtr tuple_conde = block->func_graph()->NewCNode(node_vec); + CNodePtr tuple_conde = block->func_graph()->NewCNodeInOrder(node_vec); return tuple_conde; } 
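For orientation while reading the remaining hunks: the UpdateState and Load primitives that this patch threads through the IR turn execution order into ordinary data dependencies. An effectful node consumes the current monad value, UpdateState produces the successor monad, and Load reads a Ref parameter "at" a given monad, so anything using the newer monad is necessarily ordered after the effect. A rough conceptual sketch with hypothetical node types (not MindSpore's real IR classes):

#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

struct Node;
using NodePtr = std::shared_ptr<Node>;

struct Node {
  std::string op;
  std::vector<NodePtr> inputs;
};

NodePtr Make(std::string op, std::vector<NodePtr> inputs = {}) {
  return std::make_shared<Node>(Node{std::move(op), std::move(inputs)});
}

int main() {
  NodePtr u0 = Make("U");             // the initial monad value
  NodePtr w = Make("param w (Ref)");  // a parameter with Ref abstract
  // Reading w "at time u0": Load(w, u0).
  NodePtr load0 = Make("Load", {w, u0});
  // An effectful write also consumes the current monad...
  NodePtr assign = Make("Assign", {w, load0, u0});
  // ...and UpdateState produces the next monad, which depends on the write.
  NodePtr u1 = Make("UpdateState", {u0, assign});
  // A later read uses u1, so it is data-dependent on, and therefore
  // ordered after, the Assign above.
  NodePtr load1 = Make("Load", {w, u1});
  std::cout << load1->op << " is ordered after " << load1->inputs[1]->inputs[1]->op << "\n";
  return 0;
}

Under this reading, UpdatestateEliminater and the monad-skipping checks (HasAbstractMonad) in the parallel passes are just cleanups: once ordering is no longer needed, redundant UpdateState links can be collapsed and monad inputs ignored when counting "real" operands.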
@@ -1007,7 +1014,7 @@ AnfNodePtr Parser::ParseUnaryOp(const FunctionBlockPtr &block, const py::object py::object operand = python_adapter::GetPyObjAttr(node, "operand"); AnfNodePtr operand_node = ParseExprNode(block, operand); - return block->func_graph()->NewCNode({op_node, operand_node}); + return block->func_graph()->NewCNodeInOrder({op_node, operand_node}); } // process a dict ast node expression @@ -1025,7 +1032,7 @@ AnfNodePtr Parser::ParseDict(const FunctionBlockPtr &block, const py::object &no auto values_tuple = GenerateMakeTuple(block, value_nodes); MS_EXCEPTION_IF_NULL(block); auto make_dict_op = block->MakeResolveOperation(NAMED_PRIMITIVE_MAKEDICT); - return block->func_graph()->NewCNode({make_dict_op, keys_tuple, values_tuple}); + return block->func_graph()->NewCNodeInOrder({make_dict_op, keys_tuple, values_tuple}); } // process a augment assign such as a += b or mat[stride_slice] += b. @@ -1054,7 +1061,7 @@ FunctionBlockPtr Parser::ParseAugAssign(const FunctionBlockPtr &block, const py: if (target_node == nullptr) { MS_LOG(EXCEPTION) << "Can not get target node "; } - CNodePtr augassign_app = block->func_graph()->NewCNode({op_node, target_node, value_node}); + CNodePtr augassign_app = block->func_graph()->NewCNodeInOrder({op_node, target_node, value_node}); WriteAssignVars(block, target_obj, augassign_app); return block; } @@ -1180,13 +1187,13 @@ CNodePtr Parser::GenerateIteratorInFor(const FunctionBlockPtr &block, const py:: const AnfNodePtr &op_iter) { py::object iter_node = python_adapter::GetPyObjAttr(node, "iter"); AnfNodePtr iter_anf_node = ParseExprNode(block, iter_node); - return block->func_graph()->NewCNode({op_iter, iter_anf_node}); + return block->func_graph()->NewCNodeInOrder({op_iter, iter_anf_node}); } CNodePtr Parser::GenerateCondInFor(const ParameterPtr &iter_param, const FunctionBlockPtr &header_block, const AnfNodePtr &op_hasnext) { MS_EXCEPTION_IF_NULL(header_block); - return header_block->func_graph()->NewCNode({op_hasnext, iter_param}); + return header_block->func_graph()->NewCNodeInOrder({op_hasnext, iter_param}); } FunctionBlockPtr Parser::GenerateBlockInFor(const TraceInfoPtr &trace_info) { @@ -1225,8 +1232,8 @@ FunctionBlockPtr Parser::ParseFor(const FunctionBlockPtr &block, const py::objec AnfNodePtr op_len = block->MakeResolveSymbol(NAMED_PRIMITIVE_LEN); py::object iter_obj = python_adapter::GetPyObjAttr(node, NAMED_PRIMITIVE_ITER); AnfNodePtr iter_node = ParseExprNode(block, iter_obj); - CNodePtr len_iter = block->func_graph()->NewCNode({op_len, iter_node}); - CNodePtr bool_node = block->func_graph()->NewCNode( + CNodePtr len_iter = block->func_graph()->NewCNodeInOrder({op_len, iter_node}); + CNodePtr bool_node = block->func_graph()->NewCNodeInOrder( {NewValueNode(prim::kPrimScalarLt), len_iter, NewValueNode(GetForTransToWhileLoop())}); // create statement 'if len(xs) < prim::MAX_FOR_LOOP_COUNT then ParseForIter else ParseForLoop' @@ -1290,11 +1297,13 @@ FunctionBlockPtr Parser::ParseForIter(const FunctionBlockPtr &block, const py::o body_block->AddPrevBlock(header_block); // generate the iterator next apply // process as following: `app = next(it); target = app[0]; it = app[1];` - CNodePtr app = body_block->func_graph()->NewCNode({op_next, iter_param}); - CNodePtr target_app = body_block->func_graph()->NewCNode({op_getitem, app, NewValueNode(static_cast(0))}); + CNodePtr app = body_block->func_graph()->NewCNodeInOrder({op_next, iter_param}); + CNodePtr target_app = + body_block->func_graph()->NewCNodeInOrder({op_getitem, app, 
NewValueNode(static_cast(0))}); py::object target_node = python_adapter::GetPyObjAttr(node, "target"); - CNodePtr iter2_app = body_block->func_graph()->NewCNode({op_getitem, app, NewValueNode(static_cast(1))}); + CNodePtr iter2_app = + body_block->func_graph()->NewCNodeInOrder({op_getitem, app, NewValueNode(static_cast(1))}); WriteAssignVars(body_block, target_node, target_app); // link the variable name with the target @@ -1351,7 +1360,7 @@ FunctionBlockPtr Parser::ParseForLoop(const FunctionBlockPtr &block, const py::o AnfNodePtr op_len = block->MakeResolveSymbol(NAMED_PRIMITIVE_LEN); AnfNodePtr op_getitem = block->MakeResolveOperation(NAMED_PRIMITIVE_GETITEM); - // get varibale name of 'x' in statement 'for x in xs' + // get variable name of 'x' in statement 'for x in xs' py::object target_node = python_adapter::GetPyObjAttr(node, "target"); // create statement 'len(xs)' @@ -1359,11 +1368,11 @@ FunctionBlockPtr Parser::ParseForLoop(const FunctionBlockPtr &block, const py::o AnfNodePtr iter_node = ParseExprNode(block, iter_obj); MS_EXCEPTION_IF_NULL(iter_node); // Generate node for loop count and convert it to tensor, to make the loop not unroll - CNodePtr scalar_len = block->func_graph()->NewCNode({op_len, iter_node}); + CNodePtr scalar_len = block->func_graph()->NewCNodeInOrder({op_len, iter_node}); auto scalar_to_tensor = prim::GetPythonOps("ScalarToTensor", "mindspore.ops.operations"); - auto scalar_to_tensor_node = block->func_graph()->NewCNode({NewValueNode(scalar_to_tensor)}); + auto scalar_to_tensor_node = block->func_graph()->NewCNodeInOrder({NewValueNode(scalar_to_tensor)}); - CNodePtr len_iter = block->func_graph()->NewCNode({scalar_to_tensor_node, scalar_len}); + CNodePtr len_iter = block->func_graph()->NewCNodeInOrder({scalar_to_tensor_node, scalar_len}); FunctionBlockPtr header_block = GenerateBlockInFor(std::make_shared(block->func_graph()->debug_info())); @@ -1372,18 +1381,18 @@ FunctionBlockPtr Parser::ParseForLoop(const FunctionBlockPtr &block, const py::o ParameterPtr loop_var = header_block->func_graph()->add_parameter(); // create loop condition 'i < len(xs)' auto prim_less = prim::GetPythonOps("Less", "mindspore.ops.operations"); - auto less_node = header_block->func_graph()->NewCNode({NewValueNode(prim_less)}); - CNodePtr cond_node = header_block->func_graph()->NewCNode({less_node, loop_var, len_iter}); + auto less_node = header_block->func_graph()->NewCNodeInOrder({NewValueNode(prim_less)}); + CNodePtr cond_node = header_block->func_graph()->NewCNodeInOrder({less_node, loop_var, len_iter}); // generate the body of the for statement FunctionBlockPtr body_block = GenerateBlockInFor(std::make_shared(block->func_graph()->debug_info())); MS_EXCEPTION_IF_NULL(body_block); body_block->AddPrevBlock(header_block); // create 'x = xs[i]' - CNodePtr target_var = body_block->func_graph()->NewCNode({op_getitem, iter_node, loop_var}); + CNodePtr target_var = body_block->func_graph()->NewCNodeInOrder({op_getitem, iter_node, loop_var}); WriteAssignVars(body_block, target_node, target_var); // create 'i = i + 1' - CNodePtr loop_var_inc = body_block->func_graph()->NewCNode( + CNodePtr loop_var_inc = body_block->func_graph()->NewCNodeInOrder( {NewValueNode(prim::kPrimScalarAdd), loop_var, NewValueNode(static_cast(1))}); body_block->WriteVariable(loop_var->name(), loop_var_inc); @@ -1461,12 +1470,12 @@ AnfNodePtr Parser::ParseIfExp(const FunctionBlockPtr &block, const py::object &n // Use the Primitive replace the operation resolve node (switch) // because the switch will eventually 
be converted to Primitive node - CNodePtr switch_app = - block->func_graph()->NewCNode({NewValueNode(prim::kPrimSwitch), bool_node, NewValueNode(true_block->func_graph()), - NewValueNode(false_block->func_graph())}); + CNodePtr switch_app = block->func_graph()->NewCNodeInOrder({NewValueNode(prim::kPrimSwitch), bool_node, + NewValueNode(true_block->func_graph()), + NewValueNode(false_block->func_graph())}); std::vector call_graph_nodes{switch_app}; - CNodePtr switch_app_call = block->func_graph()->NewCNode(call_graph_nodes); + CNodePtr switch_app_call = block->func_graph()->NewCNodeInOrder(call_graph_nodes); return switch_app_call; } @@ -1495,7 +1504,7 @@ void Parser::HandleAssignTuple(const FunctionBlockPtr &block, const py::object & // Use the Primitive replace the operation resolve node (getitem) // because the getitem will eventually be converted to Primitive node CNodePtr item_apply = - block->func_graph()->NewCNode({op_getitem, assigned_node, NewValueNode(static_cast(i))}); + block->func_graph()->NewCNodeInOrder({op_getitem, assigned_node, NewValueNode(static_cast(i))}); py::object elt = items[i]; WriteAssignVars(block, elt, item_apply); @@ -1509,9 +1518,7 @@ void Parser::HandleAssignClassMember(const FunctionBlockPtr &block, const py::ob MS_EXCEPTION_IF_NULL(target_node); std::string attr_name = targ.attr("attr").cast(); - std::string var_name = "self."; - (void)var_name.append(attr_name); - MS_LOG(DEBUG) << "assign " << var_name; + std::string var_name = "self." + attr_name; // Now only support the self.xxx = yyy, where self.xxx must be a defined Parameter type if (!py::hasattr(ast()->obj(), common::SafeCStr(attr_name))) { @@ -1526,9 +1533,8 @@ void Parser::HandleAssignClassMember(const FunctionBlockPtr &block, const py::ob } MS_EXCEPTION_IF_NULL(block); - block->WriteVariable(var_name, assigned_node); MS_LOG(DEBUG) << "SetState write " << var_name << " : " << target_node->ToString(); - block->SetStateAssgin(target_node, var_name); + block->SetStateAssign(target_node, assigned_node); } void Parser::HandleAssignSubscript(const FunctionBlockPtr &block, const py::object &targ, @@ -1539,7 +1545,7 @@ void Parser::HandleAssignSubscript(const FunctionBlockPtr &block, const py::obje py::object slice_obj = python_adapter::GetPyObjAttr(targ, "slice"); AnfNodePtr value_node = ParseExprNode(block, value_obj); AnfNodePtr slice_node = ParseExprNode(block, slice_obj); - CNodePtr setitem_app = block->func_graph()->NewCNode({op_setitem, value_node, slice_node, assigned_node}); + CNodePtr setitem_app = block->func_graph()->NewCNodeInOrder({op_setitem, value_node, slice_node, assigned_node}); // getitem apply should return the sequence data structure itself std::string var_name; if (ast_->IsClassMember(value_obj)) { @@ -1590,7 +1596,7 @@ void Parser::WriteAssignVars(const FunctionBlockPtr &block, const py::object &ta // process a assign statement, such as a =b, a,b = tup FunctionBlockPtr Parser::ParseAssign(const FunctionBlockPtr &block, const py::object &node) { - MS_LOG(DEBUG) << "Process ast assgin"; + MS_LOG(DEBUG) << "Process ast assign"; py::object value_object = python_adapter::GetPyObjAttr(node, "value"); AnfNodePtr value_node = ParseExprNode(block, value_object); py::object targets_object = python_adapter::GetPyObjAttr(node, "targets"); @@ -1667,7 +1673,6 @@ void Parser::RemoveUnnecessaryPhis() { for (int64_t idx = SizeToLong(phis.size() - 1); idx >= 0; idx--) { auto phi = phis[LongToSize(idx)]; auto new_node = FindPhis(removable_phis, phi); - MS_LOG(DEBUG) << "phi " << phi->DebugString() << 
" to " << new_node->DebugString(); mng->Replace(phi, new_node); } // remove the parameter @@ -1837,7 +1842,7 @@ static AnfNodePtr CopyNodesFromParamDefaultValue(const FuncGraphPtr func_graph, std::size_t index = 0; std::vector old_cnodes; old_cnodes.emplace_back(param_node); - auto res = func_graph->NewCNode({}); + auto res = func_graph->NewCNodeInOrder({}); std::vector new_cnodes; new_cnodes.emplace_back(res); while (index < old_cnodes.size()) { @@ -1851,7 +1856,7 @@ static AnfNodePtr CopyNodesFromParamDefaultValue(const FuncGraphPtr func_graph, AnfNodePtr input = *it; if (input != nullptr && input->isa()) { old_cnodes.emplace_back(input); - auto new_cnode = func_graph->NewCNode({}); + auto new_cnode = func_graph->NewCNodeInOrder({}); new_cnodes.emplace_back(new_cnode); current_new_cnode->add_input(new_cnode); } else if (input->isa()) { @@ -1905,7 +1910,7 @@ FuncGraphPtr MakeTopGraph(const py::object &cell, const ValuePtr &cell_ptr) { auto ¶ms = func_graph->parameters(); (void)std::transform(params.begin(), params.end(), std::back_inserter(inputs), [](AnfNodePtr node) -> AnfNodePtr { return node; }); - func_graph->set_output(func_graph->NewCNode(inputs)); + func_graph->set_output(func_graph->NewCNodeInOrder(inputs)); } else { // ret = cell_obj(*arg, *kwargs) auto call_fn = MakeUnpackCall(func_graph, NewValueNode(cell_ptr), func_graph->parameters()); diff --git a/mindspore/ccsrc/pipeline/jit/parse/resolve.cc b/mindspore/ccsrc/pipeline/jit/parse/resolve.cc index 254b4be403..69c636dc0a 100644 --- a/mindspore/ccsrc/pipeline/jit/parse/resolve.cc +++ b/mindspore/ccsrc/pipeline/jit/parse/resolve.cc @@ -109,7 +109,7 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object node->set_abstract(abs); para_node = node; } - func_graph->add_parameter_obj_node(para_node); + func_graph->add_used_global_parameters(para_node); return para_node; } diff --git a/mindspore/ccsrc/pipeline/jit/pass.cc b/mindspore/ccsrc/pipeline/jit/pass.cc index b062651755..9d1ab1a7c7 100644 --- a/mindspore/ccsrc/pipeline/jit/pass.cc +++ b/mindspore/ccsrc/pipeline/jit/pass.cc @@ -32,7 +32,6 @@ #include "frontend/optimizer/graph_kernel_reuse.h" #include "frontend/optimizer/clean.h" #include "frontend/optimizer/irpass.h" -#include "frontend/optimizer/control_depend.h" #include "frontend/optimizer/graph_transform.h" #include "frontend/parallel/step_parallel.h" #include "frontend/parallel/step_auto_parallel.h" @@ -41,9 +40,11 @@ #include "frontend/optimizer/recompute.h" #include "utils/log_adapter.h" #include "pipeline/jit/pipeline_split.h" +#include "pipeline/jit/static_analysis/auto_monad.h" #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #include "ps/util.h" #endif + namespace mindspore { namespace pipeline { using OptPassGroupMap = opt::OptPassGroupMap; @@ -90,13 +91,19 @@ bool CleanAfterOptAPass(const ResourcePtr &res) { } namespace { +bool ReAutoMonadWrapper(const FuncGraphPtr &root, const opt::OptimizerPtr &) { return ReAutoMonad(root); } + OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { opt::OptPassConfig a_1 = opt::OptPassConfig({ irpass.switch_layer_defer_inline_, irpass.switch_simplify_, + irpass.exchange_switch_depend_value_, + irpass.float_depend_g_call_, // Safe inlining irpass.inline_, + irpass.updatestate_eliminater_, + irpass.stopgrad_eliminater_, irpass.partial_eliminate_, irpass.replace_applicator_, @@ -122,6 +129,8 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { // Safe inlining irpass.inline_, + irpass.updatestate_eliminater_, + 
irpass.stopgrad_eliminater_, irpass.sparse_tensor_eliminate_, }); opt::OptPassConfig a_2 = opt::OptPassConfig({ @@ -147,6 +156,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { irpass.switch_layer_defer_inline_, irpass.replace_applicator_, irpass.mirror_mini_step_elim_, + irpass.row_tensor_add_zeros_like_, }); opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_}); opt::OptPassConfig grad = opt::OptPassConfig({irpass.expand_jprim_}, true); @@ -167,6 +177,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { {"resolve", resolve_pass}, {"a_after_grad", a_after_grad}, {"renormalize", opt::OptPassConfig::Renormalize()}, + {"auto_monad_grad", opt::OptPassConfig(ReAutoMonadWrapper)}, {"cse", opt::OptPassConfig(opt::CSEPass(false))}, {"a_3", a_3}}); @@ -183,6 +194,9 @@ OptPassGroupMap GetOptPassesAfterCconv(const opt::irpass::OptimizeIRPassLib &irp opt::OptPassConfig c_1 = opt::OptPassConfig({ // Safe inlining, irpass.inline_, + irpass.updatestate_eliminater_, + irpass.switch_call_monad_eliminater_, + irpass.stopgrad_eliminater_, irpass.partial_eliminate_, }); @@ -206,8 +220,9 @@ OptPassGroupMap GetOptPassesTransformGraph(const opt::irpass::OptimizeIRPassLib OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) { opt::OptPassConfig b_1 = opt::OptPassConfig( {irpass.zero_like_fill_zero_, irpass.item_tuple_or_list_eliminate_, irpass.float_tuple_getitem_switch_, - irpass.reset_defer_inline_, irpass.inline_, irpass.special_op_eliminate_, irpass.get_make_ref_eliminate_, - irpass.incorporate_env_getitem_, irpass.incorporate_env_getitem_switch_, irpass.env_get_item_eliminate_, + irpass.reset_defer_inline_, irpass.inline_, irpass.updatestate_eliminater_, irpass.stopgrad_eliminater_, + irpass.special_op_eliminate_, irpass.get_make_ref_eliminate_, irpass.incorporate_env_getitem_, + irpass.incorporate_env_getitem_switch_, irpass.env_get_item_eliminate_, irpass.incorporate_env_getitem_switch_layer_, irpass.value_based_eliminate_, irpass.receive_eliminate_}); opt::OptPassConfig b_2 = opt::OptPassConfig({ irpass.replace_refkey_by_param_, @@ -380,22 +395,6 @@ bool OptPassRNGroup(const ResourcePtr &res) { return OptPassGroup(res, "renormal bool OptPassGradEpilogueGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_grad_epilogue"); } -bool AddControlDependPass(const ResourcePtr &res) { - FuncGraphPtr func_graph = res->func_graph(); - MS_EXCEPTION_IF_NULL(func_graph); - - if (func_graph->has_flag(GRAPH_FLAG_EFFECT_PATIAL_ORDER)) { - opt::AddControlDepend(func_graph); - } - for (auto fg : func_graph->func_graphs_used_total()) { - MS_EXCEPTION_IF_NULL(fg); - if (fg->has_flag(GRAPH_FLAG_EFFECT_PATIAL_ORDER)) { - opt::AddControlDepend(fg); - } - } - return true; -} - bool AddRecomputationPass(const ResourcePtr &res) { MS_EXCEPTION_IF_NULL(res); opt::InsertRecomputedNodes(res->func_graph()); @@ -532,7 +531,6 @@ std::vector kVmPasses = {{"simplify_data_structures", SimplifyDataStru {"opt_graph_kernel_a", OptPassGraphKernelGroupA}, {"opt_graph_kernel_b", OptPassGraphKernelGroupB}, {"add_cache_embedding", AddCacheEmbeddingPass}, - {"add_control_depend", AddControlDependPass}, {"add_recomputation", AddRecomputationPass}, {"cse_after_recomputation", OptAfterRecomputeGroup}}; @@ -540,7 +538,6 @@ std::vector kGePasses = {{"simplify_data_structures", SimplifyDataStru {"opt_a", OptPassAGroup}, {"clean_after_opta", CleanAfterOptAPass}, {"opt_b", OptPassBGroup}, - {"add_control_depend", 
AddControlDependPass}, {"opt_control", ControlGroup}, {"opt_prepare", PrepareGroup}, {"cconv", CconvPass}}; diff --git a/mindspore/ccsrc/pipeline/jit/pass.h b/mindspore/ccsrc/pipeline/jit/pass.h index 1285fd14bb..50abeea7bf 100644 --- a/mindspore/ccsrc/pipeline/jit/pass.h +++ b/mindspore/ccsrc/pipeline/jit/pass.h @@ -36,7 +36,6 @@ bool CconvPass(const ResourcePtr &res); bool PipelineSplitPass(const ResourcePtr &res); bool ValidatePass(const ResourcePtr &res); bool ConvertPrepareAdapt(const ResourcePtr &res); -bool AddControlDependPass(const ResourcePtr &res); bool AddCacheEmbeddingPass(const ResourcePtr &res); bool InferenceOptPreparePass(const ResourcePtr &res); void ReclaimOptimizer(); diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.cc b/mindspore/ccsrc/pipeline/jit/pipeline.cc index 03c3bdeda0..d7b9cdeffe 100644 --- a/mindspore/ccsrc/pipeline/jit/pipeline.cc +++ b/mindspore/ccsrc/pipeline/jit/pipeline.cc @@ -732,23 +732,6 @@ void Pipeline::Run() { // generate IR file in a heavily commented format, which can also be reloaded ExportIR(base_name + ".dat", std::to_string(i), graph); } -#ifdef MS_DEBUG - // Dump graph cnode list - MS_LOG(INFO) << "Show CNode list after " << action.first; - graph->DumpCNodeList(); -#endif - } - if (resource_->func_graph() != nullptr) { - auto func_graph = resource_->func_graph(); - if (func_graph->has_flag(GRAPH_FLAG_HAS_EFFECT)) { - func_graph->EraseUnusedNodeInOrder(); - func_graph->CheckOrder(); - for (auto fg : func_graph->func_graphs_used_total()) { - MS_LOG(DEBUG) << "Check order graph " << fg->ToString() << "."; - fg->EraseUnusedNodeInOrder(); - fg->CheckOrder(); - } - } } i++; #ifdef ENABLE_TIMELINE diff --git a/mindspore/ccsrc/pipeline/jit/resource.cc b/mindspore/ccsrc/pipeline/jit/resource.cc index cb746d8b85..ed111d7bec 100644 --- a/mindspore/ccsrc/pipeline/jit/resource.cc +++ b/mindspore/ccsrc/pipeline/jit/resource.cc @@ -185,6 +185,10 @@ BuiltInTypeMap &GetMethodMap() { {"astype", std::string("astype")}, // P.cast() {"__bool__", std::string("tensor_bool")}, // C.tensor_bool }}, + {kObjectTypeRowTensorType, + { + {"__add__", prim::kPrimRowTensorAdd}, // P.row_tensor_add + }}, {kObjectTypeJTagged, {}}, {kObjectTypeSymbolicKeyType, {}}, {kObjectTypeEnvType, {}}}; diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/auto_monad.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/auto_monad.cc new file mode 100644 index 0000000000..4c3eae0362 --- /dev/null +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/auto_monad.cc @@ -0,0 +1,1430 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+
+#include "pipeline/jit/static_analysis/auto_monad.h"
+#include <set>
+#include <map>
+#include <stack>
+#include <vector>
+#include <string>
+#include <memory>
+#include <algorithm>
+#include <functional>
+#include "pipeline/jit/parse/resolve.h"
+#include "frontend/operator/ops.h"
+#include "frontend/operator/composite/multitype_funcgraph.h"
+#include "utils/flags.h"
+#include "utils/ordered_map.h"
+#include "base/core_ops.h"
+#include "abstract/abstract_value.h"
+
+namespace mindspore::pipeline {
+namespace {  // namespace anonymous
+
+using ClassTypePtr = std::shared_ptr<parse::ClassType>;
+using RefInputs = OrderedMap<AnfNodePtr, std::vector<size_t>>;
+
+// Add or get a monad parameter.
+AnfNodePtr AddMonadParameter(const FuncGraphPtr &func_graph, const std::string &name,
+                             const abstract::AbstractBasePtr &abs) {
+  // Search for existing parameters and return one if found.
+  for (auto &node : func_graph->parameters()) {
+    auto para = dyn_cast<Parameter>(node);
+    if (para == nullptr) {
+      continue;
+    }
+    auto para_abs = para->abstract();
+    if (para_abs && *para_abs == *abs) {
+      return para;
+    }
+  }
+  // Create a new parameter if none exists.
+  auto para = std::make_shared<Parameter>(func_graph);
+  para->set_name(name);
+  para->debug_info()->set_name(name);
+  para->set_abstract(abs);
+  func_graph->add_parameter(para);
+  return para;
+}
+
+// Gets the side effect propagate attribute value from a ClassType object.
+int GetSideEffectPropagate(const ClassTypePtr &class_type) {
+  if (class_type) {
+    auto obj = class_type->obj();
+    if (py::hasattr(obj, GRAPH_FLAG_SIDE_EFFECT_PROPAGATE)) {
+      auto value = py::getattr(obj, GRAPH_FLAG_SIDE_EFFECT_PROPAGATE);
+      return value.cast<int>();
+    }
+  }
+  return 0;
+}
+
+// Gets the 'side_effect_propagate' attribute value from a primitive.
+int GetSideEffectPropagate(const PrimitivePtr &prim) {
+  if (prim) {
+    auto attr = prim->GetAttr(GRAPH_FLAG_SIDE_EFFECT_PROPAGATE);
+    if (attr && attr->isa<Int64Imm>()) {
+      return static_cast<int>(attr->cast<Int64ImmPtr>()->value());
+    }
+  }
+  return 0;
+}
+
+// Return true if the node has a Ref abstract.
+bool HasAbstractRef(const AnfNodePtr &node) {
+  if (node == nullptr) {
+    return false;
+  }
+  auto &abs = node->abstract();
+  return (abs != nullptr) && abs->isa<abstract::AbstractRef>();
+}
+
+// Gets ref inputs and their indexes from a cnode.
+RefInputs GetRefInputs(const CNodePtr &cnode) {
+  RefInputs ref_inputs;
+  for (size_t i = 1; i < cnode->size(); ++i) {
+    auto &input = cnode->inputs().at(i);
+    if (HasAbstractRef(input)) {
+      ref_inputs[input].push_back(i);
+    }
+  }
+  return ref_inputs;
+}
+
+// Return true if the cnode has a ref input.
+bool HasRefInput(const CNodePtr &cnode) {
+  if (cnode == nullptr || cnode->inputs().empty()) {
+    return false;
+  }
+  auto &inputs = cnode->inputs();
+  // Return true if any of the arguments is a ref.
+  return std::any_of(inputs.begin() + 1, inputs.end(), [](const auto &input) { return HasAbstractRef(input); });
+}
+
+// Return true if we don't need a Load for the given primitive,
+// i.e. keep Ref as Ref for some primitives.
+bool IsKeepRef(const PrimitivePtr &prim) {
+  return (GetSideEffectPropagate(prim) != 0) || IsPrimitiveEquals(prim, prim::kPrimRefToEmbed) ||
+         IsPrimitiveEquals(prim, prim::kPrimPull);
+}
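+
+// For example (hypothetical snippet, for illustration only), a forwarding
+// primitive can declare that side effects propagate through its input 1:
+//   auto depend = std::make_shared<Primitive>("Depend");
+//   depend->set_attr(GRAPH_FLAG_SIDE_EFFECT_PROPAGATE, MakeValue(static_cast<int64_t>(1)));
+// GetSideEffectPropagate(depend) then returns 1, so IsKeepRef(depend) is true
+// and no Load is inserted for its Ref inputs.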
+
+// Gets the primitive if the node is a primitive value node.
+PrimitivePtr GetPrimitive(const AnfNodePtr &node) {
+  PrimitivePtr prim = GetValueNode<PrimitivePtr>(node);
+  auto do_sig = dyn_cast<prim::DoSignaturePrimitive>(prim);
+  if (do_sig) {
+    auto val = do_sig->function();
+    return dyn_cast<Primitive>(val);
+  }
+  return prim;
+}
+
+// Gets the primitive from the given cnode; return nullptr if cnode.inputs[0] is not a primitive.
+PrimitivePtr GetPrimitive(const CNodePtr &cnode) {
+  if (cnode == nullptr || cnode->inputs().empty()) {
+    return nullptr;
+  }
+  return GetPrimitive(cnode->input(0));
+}
+
+// Gets the func_graph from the given cnode; return nullptr if it is not a func graph call.
+FuncGraphPtr GetFuncGraph(const CNodePtr &cnode) {
+  if (cnode != nullptr && !cnode->inputs().empty()) {
+    return GetValueNode<FuncGraphPtr>(cnode->input(0));
+  }
+  return nullptr;
+}
+
+// Gets the class_type from the given cnode->inputs[0].
+ClassTypePtr GetClassType(const CNodePtr &cnode) {
+  if (cnode && !cnode->inputs().empty()) {
+    auto apply = cnode->input(0);
+    auto apply_cnode = dyn_cast<CNode>(apply);
+    if (apply_cnode && !apply_cnode->inputs().empty()) {
+      return GetValueNode<ClassTypePtr>(apply_cnode->input(0));
+    }
+  }
+  return nullptr;
+}
+
+// Gets the first input as a cnode from the given cnode;
+// return null if input[0] is not a cnode.
+CNodePtr GetFuncCNode(const CNodePtr &cnode) {
+  if (cnode != nullptr && !cnode->inputs().empty()) {
+    return dyn_cast<CNode>(cnode->input(0));
+  }
+  return nullptr;
+}
+
+// Gets the first input as a function parameter from the given cnode;
+// return null if input[0] is not a parameter.
+ParameterPtr GetFuncParameter(const CNodePtr &cnode) {
+  if (cnode != nullptr && !cnode->inputs().empty()) {
+    return dyn_cast<Parameter>(cnode->input(0));
+  }
+  return nullptr;
+}
+
+// Gets the first input as a MultitypeFuncGraph from the given cnode;
+// return null if input[0] is not a MultitypeFuncGraph.
+prim::MultitypeFuncGraphPtr GetFuncMultitypeFuncGraph(const CNodePtr &cnode) {
+  if (cnode != nullptr && !cnode->inputs().empty()) {
+    return GetValueNode<prim::MultitypeFuncGraphPtr>(cnode->input(0));
+  }
+  return nullptr;
+}
+
+// --------------------------------------------------------------------
+// SCC (Strongly Connected Components) related types.
+// --------------------------------------------------------------------
+using SccVector = std::set<FuncGraphPtr>;
+using SccPtr = std::shared_ptr<SccVector>;
+using SccMap = std::unordered_map<FuncGraphPtr, SccPtr>;
+
+// ---------------------------------------------------------------------
+// SccFinder finds SCCs using Tarjan's algorithm.
+// ---------------------------------------------------------------------
+class SccFinder {
+ public:
+  explicit SccFinder(FuncGraphPtr root) : root_(root) {}
+  ~SccFinder() = default;
+  void Run() { (void)Search(root_); }
+  SccMap &scc_map() { return scc_map_; }
+
+ private:
+  // Save the state of a func graph.
+  struct State {
+    size_t index = 0;
+    size_t lowlink = 0;
+    bool in_stack = false;
+    explicit State(size_t index) : index(index), lowlink(index), in_stack(false) {}
+    ~State() = default;
+  };
+
+  // Search SCCs from the given graph.
+  const State &Search(FuncGraphPtr graph) {
+    // Create the graph state, set it as visited.
+    auto [inserted, ok] = visited_.emplace(graph, State(index_++));
+    if (!ok) {
+      MS_LOG(EXCEPTION) << "Already visited: " << graph->ToString();
+    }
+    auto &state = inserted->second;
+    // Push the visited graph to the stack.
+    stack_.push(graph);
+    state.in_stack = true;
+    // Search successor graphs.
+    for (auto &used : graph->func_graphs_used()) {
+      auto &sg = used.first;
+      auto iter = visited_.find(sg);
+      if (iter == visited_.end()) {
+        // Successor graph has not yet been visited, recurse on it.
+        auto &sg_state = Search(sg);
+        state.lowlink = std::min(state.lowlink, sg_state.lowlink);
+      } else if (iter->second.in_stack) {
+        // Successor graph is in the stack and hence in the current SCC.
+        state.lowlink = std::min(state.lowlink, iter->second.index);
+      }
+    }
+    // If index == lowlink, this graph is the root of an SCC.
+    if (state.index == state.lowlink) {
+      // Pop members of the SCC from the stack; they are on top of its root.
+      auto scc = std::make_shared<SccVector>();
+      while (!stack_.empty()) {
+        auto g = stack_.top();
+        stack_.pop();
+        auto found = visited_.find(g);
+        if (found == visited_.end()) {
+          MS_LOG(EXCEPTION) << "Unexpected graph: " << g->ToString();
+        }
+        found->second.in_stack = false;
+        // Add the graph to the SCC, and create the map from graph to SCC.
+        scc->insert(g);
+        scc_map_.emplace(g, scc);
+        if (g == graph) {
+          break;
+        }
+      }
+      // The SCC should not be empty.
+      if (scc->empty()) {
+        MS_LOG(EXCEPTION) << "Invalid SCC for: " << graph->ToString();
+      }
+    }
+    return state;
+  }
+
+ private:
+  // The root graph.
+  FuncGraphPtr root_;
+
+  // Current index by DFS order.
+  size_t index_ = 1;
+
+  // Visited graphs and their states.
+  std::unordered_map<FuncGraphPtr, State> visited_;
+
+  // The stack for Tarjan's algorithm.
+  std::stack<FuncGraphPtr> stack_;
+
+  // The result SCC map, from graph to its SCC.
+  SccMap scc_map_;
+};
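+
+// Illustrative example: if graphs f and g call each other and main calls f,
+// the result is scc_map = {f -> {f, g}, g -> {f, g}, main -> {main}}, so a
+// side effect detected anywhere inside {f, g} is attributed to both graphs.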
+
+struct SwitchLayerCall {
+  CNodePtr caller;
+  EffectInfo effect_info;
+  std::vector<FuncGraphPtr> branches;
+};
+
+// -------------------------------------------------------------------------------
+// SideEffectFinder searches and marks side effects for a graph and its sub-graphs.
+// -------------------------------------------------------------------------------
+class SideEffectFinder {
+ public:
+  static void Search(const FuncGraphPtr &root) {
+    SideEffectFinder finder(root);
+    finder.Run();
+  }
+
+ private:
+  explicit SideEffectFinder(const FuncGraphPtr &root) : root_(root) {}
+  ~SideEffectFinder() = default;
+
+  void Run() {
+    // To handle recursive calls, we generate the SCC map before the search.
+    GenerateSccMap();
+    // Update order lists to include outer cnodes.
+    UpdateOrderLists();
+    // Find side effects by DFS from the top graph.
+    (void)GetEffectInfo(root_);
+    // Check switch layer calls, and add monad arguments if needed.
+    HandleSwitchLayerCalls();
+  }
+
+  void UpdateOrderLists() {
+    // Some cnodes are used in the current func graph but belong to other func graphs;
+    // we have to insert them into the order list so that we can handle their side effects.
+    UpdateOrderList(root_);
+    for (auto &fg : root_->func_graphs_used_total()) {
+      UpdateOrderList(fg);
+    }
+  }
+
+  static void UpdateOrderList(const FuncGraphPtr &func_graph) {
+    std::list<CNodePtr> new_order_list;
+    const auto &order_list = func_graph->order_list();
+    for (auto &cnode : order_list) {
+      PushToOrderList(func_graph, cnode, &new_order_list);
+    }
+    func_graph->set_order_list(std::move(new_order_list));
+  }
+
+  static void PushToOrderList(const FuncGraphPtr &fg, const CNodePtr &cnode, std::list<CNodePtr> *new_order_list) {
+    MS_EXCEPTION_IF_NULL(cnode);
+    auto iter = std::find(new_order_list->begin(), new_order_list->end(), cnode);
+    if (iter != new_order_list->end()) {
+      return;
+    }
+    for (auto &input : cnode->inputs()) {
+      auto input_cnode = dyn_cast<CNode>(input);
+      if (input_cnode != nullptr && input_cnode->func_graph() != fg) {
+        PushToOrderList(fg, input_cnode, new_order_list);
+      }
+    }
+    new_order_list->push_back(cnode);
+  }
+
+  // Generate the SCC map by SccFinder.
+  void GenerateSccMap() {
+    SccFinder scc_finder(root_);
+    scc_finder.Run();
+    scc_map_ = std::move(scc_finder.scc_map());
+  }
+
+  // Gets the branch graph from a switch cnode at the given input index.
+  FuncGraphPtr GetSwitchBranch(const CNodePtr &cnode, size_t index) {
+    return GetValueNode<FuncGraphPtr>(cnode->inputs().at(index));
+  }
+
+  // Gets branch graphs from a switch cnode.
+  std::vector<FuncGraphPtr> GetSwitchBranches(const CNodePtr &cnode) {
+    constexpr size_t switch_cnode_size = 4;
+    constexpr size_t true_index = 2;
+    constexpr size_t false_index = 3;
+    // Check size.
+    if (cnode->size() != switch_cnode_size) {
+      MS_LOG(EXCEPTION) << "Invalid switch: " << cnode->DebugString();
+    }
+    // Add both branches; in some cases, only one branch is set.
+    std::vector<FuncGraphPtr> branches;
+    auto true_branch = GetSwitchBranch(cnode, true_index);
+    if (true_branch != nullptr) {
+      branches.emplace_back(true_branch);
+    }
+    auto false_branch = GetSwitchBranch(cnode, false_index);
+    if (false_branch != nullptr) {
+      branches.emplace_back(false_branch);
+    }
+    if (branches.empty()) {
+      MS_LOG(EXCEPTION) << "Invalid switch: " << cnode->DebugString();
+    }
+    return branches;
+  }
+
+  // Add a monad parameter to switch branch graphs.
+  void AddMonadParameters(const std::vector<FuncGraphPtr> &branches, const std::string &name,
+                          const AbstractBasePtr &abs) {
+    for (auto &branch : branches) {
+      (void)AddMonadParameter(branch, name, abs);
+    }
+  }
+
+  // Trace effect info for a Switch cnode.
+  EffectInfo TraceSwitchEffectInfo(const CNodePtr &cnode) {
+    // Find branches from the switch cnode.
+    auto branches = GetSwitchBranches(cnode);
+    // In some cases, only one branch is set.
+    if (branches.size() == 1) {
+      auto &branch = branches.front();
+      // Save the branch caller, so that we can update arguments for the caller.
+      SaveBranchCaller(cnode, branch);
+      return GetEffectInfo(branch);
+    }
+    // When both branches are set, merge their effect infos.
+    EffectInfo info = MergeEffectInfo(branches);
+    if (info.state == EffectInfo::kDetected) {
+      // Set up both branches according to the merged effect info.
+      SetupEffectBranches(info, branches);
+    }
+    return info;
+  }
+
+  // Trace effect info for a SwitchLayer cnode.
+  EffectInfo TraceSwitchLayerEffectInfo(const CNodePtr &cnode) {
+    // Find branches from the switch_layer cnode.
+    auto branches = GetSwitchLayerBranches(cnode);
+    // Merge effect info from all branches.
+    EffectInfo info = MergeEffectInfo(branches);
+    if (info.state == EffectInfo::kDetected) {
+      // Set up branches according to the merged effect info.
+      SetupEffectBranches(info, branches);
+      // Save the switch_layer call, so that we can add a monad argument for it if needed.
+      auto &call = switch_layer_calls.emplace_back();
+      call.caller = caller_;
+      call.effect_info = info;
+      call.branches = std::move(branches);
+    }
+    return info;
+  }
+
+  void HandleSwitchLayerCalls() {
+    for (auto &call : switch_layer_calls) {
+      const auto &info = call.effect_info;
+      const auto &branches = call.branches;
+      auto new_info = MergeEffectInfo(branches);
+      // Reset branches if the effect info changed.
+      if (new_info.memory != info.memory || new_info.load != info.load || new_info.io != info.io) {
+        AddMonadForCaller(call.caller, new_info);
+        SetupEffectBranches(new_info, branches);
+      }
+    }
+  }
+
+  // Gets branch graphs from a switch_layer cnode.
+  std::vector<FuncGraphPtr> GetSwitchLayerBranches(const CNodePtr &cnode) {
+    constexpr size_t func_tuple_index = 2;
+    if (cnode->size() <= func_tuple_index) {
+      MS_LOG(EXCEPTION) << "Invalid switch_layer: " << cnode->DebugString(2);
+    }
+    auto func_tuple = cnode->inputs().at(func_tuple_index);
+    return GetGraphsFromTuple(func_tuple);
+  }
+
+  // Get and trace graphs from a tuple-of-funcs node for switch_layer.
+  std::vector<FuncGraphPtr> GetGraphsFromTuple(const AnfNodePtr &func_tuple) {
+    // The func tuple maker.
+    if (IsPrimitiveCNode(func_tuple, prim::kPrimMakeTuple)) {
+      return GetGraphsFromMakeTuple(func_tuple->cast<CNodePtr>());
+    }
+    // Trace a tuple from a parameter.
+    auto para = dyn_cast<Parameter>(func_tuple);
+    if (para != nullptr) {
+      std::vector<FuncGraphPtr> graphs;
+      ForEachRealArguments(para,
+                           [this, &graphs](const AnfNodePtr &arg) { graphs = std::move(GetGraphsFromTuple(arg)); });
+      return graphs;
+    }
+    // Trace a tuple returned from a func graph call.
+    auto cnode = dyn_cast<CNode>(func_tuple);
+    auto func_graph = GetFuncGraph(cnode);
+    if (func_graph != nullptr) {
+      return GetGraphsFromTuple(func_graph->output());
+    }
+    MS_LOG(EXCEPTION) << "Invalid input for switch_layer: " << func_tuple->DebugString(2);
+  }
+
+  // Get graphs from a make_tuple of funcs for switch_layer.
+  std::vector<FuncGraphPtr> GetGraphsFromMakeTuple(const CNodePtr &make_tuple) {
+    auto &inputs = make_tuple->inputs();
+    if (inputs.size() <= 1) {
+      MS_LOG(EXCEPTION) << "Invalid make_tuple for switch_layer: " << make_tuple->DebugString(2);
+    }
+    std::vector<FuncGraphPtr> graphs;
+    graphs.reserve(inputs.size() - 1);
+    for (size_t i = 1; i < inputs.size(); ++i) {
+      auto func_graph = GetValueNode<FuncGraphPtr>(inputs.at(i));
+      if (func_graph == nullptr) {
+        MS_LOG(WARNING) << "Non-graph found in switch_layer input: " << make_tuple->DebugString(2) << " index=" << i;
+        continue;
+      }
+      graphs.push_back(func_graph);
+    }
+    return graphs;
+  }
+
+  // Trace effect info from a tuple_getitem cnode.
+  EffectInfo TraceTupleGetItemEffectInfo(const CNodePtr &cnode, std::stack<int64_t> *tuple_indexes) {
+    constexpr size_t tuple_input = 1;
+    constexpr size_t index_input = 2;
+    constexpr size_t cnode_size = 3;
+    if (cnode->size() != cnode_size) {
+      MS_LOG(EXCEPTION) << "Invalid tuple_getitem: " << cnode->DebugString();
+    }
+    // Get the item index.
+    auto &index_node = cnode->inputs().at(index_input);
+    auto index_value = GetValueNode<Int64ImmPtr>(index_node);
+    if (index_value == nullptr) {
+      MS_LOG(EXCEPTION) << "Tuple_getitem with a non-const index " << cnode->DebugString();
+    }
+    int64_t index = index_value->value();
+
+    // Get the tuple value.
+    const auto &tuple_node = cnode->inputs().at(tuple_input);
+    // Push the tuple index.
+    tuple_indexes->push(index);
+    return TraceTupleEffectInfo(tuple_node, tuple_indexes);
+  }
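+
+// Illustrative example: with t = make_tuple(make_tuple(x, assign_ret)), tracing
+// tuple_getitem(tuple_getitem(t, 0), 1) pushes index 1, then index 0; the stack
+// is then unwound against the nested make_tuple nodes until assign_ret is
+// reached, whose effect info is reported for the whole getitem chain.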
+
+  EffectInfo TraceTupleEffectInfo(const AnfNodePtr &tuple_node, std::stack<int64_t> *tuple_indexes) {
+    auto para = dyn_cast<Parameter>(tuple_node);
+    if (para != nullptr) {
+      return TraceTupleParaEffectInfo(para, tuple_indexes);
+    }
+    auto tuple_cnode = dyn_cast<CNode>(tuple_node);
+    if (tuple_cnode != nullptr) {
+      return TraceTupleCNodeEffectInfo(tuple_cnode, tuple_indexes);
+    }
+    // Should not reach here.
+    MS_LOG(EXCEPTION) << "Side effects untraceable: " << tuple_node->DebugString();
+  }
+
+  EffectInfo TraceTupleParaEffectInfo(const ParameterPtr &para, std::stack<int64_t> *tuple_indexes) {
+    EffectInfo info{EffectInfo::kDetected, false, false, false};
+    ForEachRealArguments(para, [this, &info, tuple_indexes](const AnfNodePtr &arg) {
+      // Merge the real argument's effect info.
+      auto tuple_indexes_copy = *tuple_indexes;
+      auto arg_info = TraceTupleEffectInfo(arg, &tuple_indexes_copy);
+      info.Merge(arg_info);
+    });
+    return info;
+  }
+
+  EffectInfo TraceTupleCNodeEffectInfo(const CNodePtr &cnode, std::stack<int64_t> *tuple_indexes) {
+    auto prim = GetPrimitive(cnode);
+    // Trace MakeTuple.
+    if (IsPrimitiveEquals(prim, prim::kPrimMakeTuple)) {
+      if (tuple_indexes->empty()) {
+        MS_LOG(EXCEPTION) << "Unexpected make_tuple: " << cnode->DebugString(2);
+        return {EffectInfo::kDetected, false, false, false};
+      }
+      // Pop out the tuple index.
+      auto index = tuple_indexes->top();
+      tuple_indexes->pop();
+      // Follow the tuple item according to the index.
+      size_t input_index = static_cast<size_t>(index) + 1;
+      if (input_index >= cnode->size()) {
+        MS_LOG(EXCEPTION) << "Invalid make_tuple: " << cnode->DebugString() << " index=" << index;
+      }
+      if (tuple_indexes->empty()) {
+        // Trace a non-tuple.
+        return TraceEffectInfo(cnode->inputs().at(input_index));
+      }
+      // This is the tuple-of-tuple case.
+      return TraceTupleEffectInfo(cnode->inputs().at(input_index), tuple_indexes);
+    }
+    // Trace TupleGetItem (tuple of tuple).
+    if (IsPrimitiveEquals(prim, prim::kPrimTupleGetItem)) {
+      return TraceTupleGetItemEffectInfo(cnode, tuple_indexes);
+    }
+    // Trace primitives that propagate side effects from an input, such as Depend, Identity, etc.
+    int input_index = GetSideEffectPropagate(prim);
+    if (input_index > 0 && input_index < static_cast<int>(cnode->size())) {
+      return TraceTupleEffectInfo(cnode->input(static_cast<size_t>(input_index)), tuple_indexes);
+    }
+    // A tuple returned from a func graph call.
+    auto func_graph = GetFuncGraph(cnode);
+    if (func_graph != nullptr) {
+      return TraceTupleEffectInfo(func_graph->output(), tuple_indexes);
+    }
+    // The tuple is returned from J(), e.g.:
+    //   %1 = J(primal)
+    //   tuple = %1(args)
+    if (cnode->size() > 0 && IsPrimitiveCNode(cnode->input(0), prim::kPrimJ)) {
+      MS_LOG(DEBUG) << "Tuple from J: " << cnode->DebugString(2);
+      return {EffectInfo::kDetected, false, false, false};
+    }
+    // Rare case.
+    MS_LOG(WARNING) << "Tuple untraceable from: " << cnode->DebugString(2);
+    return {EffectInfo::kDetected, false, false, false};
+  }
+
+  // Set up all branches according to the effect info.
+  void SetupEffectBranches(const EffectInfo &info, const std::vector<FuncGraphPtr> &branches) {
+    // Set up monad parameters for all branches according to the effect info.
+    if (info.memory || info.load) {
+      AddMonadParameters(branches, "u", kUMonad->ToAbstract());
+    }
+    if (info.io) {
+      AddMonadParameters(branches, "io", kIOMonad->ToAbstract());
+    }
+    // Set the merged effect info to all branches.
+    for (auto &branch : branches) {
+      branch->SetEffectInfo(info);
+      // Update the caller if one exists.
+      UpdateBranchCaller(branch);
+    }
+  }
+
+  // Merge effect info for switch or switch_layer branch graphs.
+  EffectInfo MergeEffectInfo(const std::vector<FuncGraphPtr> &branches) {
+    EffectInfo info = {EffectInfo::kDetected, false, false, false};
+    for (auto &branch : branches) {
+      EffectInfo branch_info = GetEffectInfo(branch);
+      info.Merge(branch_info);
+    }
+    return info;
+  }
+
+  // Trace a cnode for effect info.
+  EffectInfo TraceEffectInfo(const CNodePtr &cnode) {
+    auto prim = GetPrimitive(cnode);
+    if (IsPrimitiveEquals(prim, prim::kPrimSwitch)) {
+      // Special handling for the Switch primitive.
+      return TraceSwitchEffectInfo(cnode);
+    }
+
+    if (IsPrimitiveEquals(prim, prim::kPrimSwitchLayer)) {
+      // Special handling for the SwitchLayer primitive.
+      return TraceSwitchLayerEffectInfo(cnode);
+    }
+
+    if (IsPrimitiveEquals(prim, prim::kPrimTupleGetItem)) {
+      // Trace tuple_getitem.
+      std::stack<int64_t> tuple_indexes;
+      return TraceTupleGetItemEffectInfo(cnode, &tuple_indexes);
+    }
+
+    // For a high-order primitive such as Partial,
+    // we trace effect info from its argument.
+    int index_prim = GetSideEffectPropagate(prim);
+    if (index_prim > 0 && index_prim < static_cast<int>(cnode->size())) {
+      return TraceEffectInfo(cnode->input(static_cast<size_t>(index_prim)));
+    }
+
+    // For func graph calls, we trace effect info from the graph output.
+    auto called_graph = GetFuncGraph(cnode);
+    if (called_graph) {
+      return TraceEffectInfo(called_graph->output());
+    }
+
+    //
+    // For a ClassType as input[0], if it is a primitive class
+    // with the 'side_effect_propagate' attribute, we trace side effects
+    // from its argument indexed by the attribute value.
+    //
+    // e.g.:
+    //   setpara = P.Partial()(P.Assign, self.para)
+    //   setpara(x)
+    //
+    auto class_type = GetClassType(cnode);
+    if (class_type) {
+      int index = GetSideEffectPropagate(class_type);
+      if (index > 0 && index < static_cast<int>(cnode->size())) {
+        return TraceEffectInfo(cnode->input(static_cast<size_t>(index)));
+      }
+    }
+
+    // Otherwise, no side effect is found and the trace stops.
+    return {EffectInfo::kDetected, false, false, false};
+  }
+
+  // Trace an AnfNode for effect info.
+  EffectInfo TraceEffectInfo(const AnfNodePtr &node) {
+    if (node) {
+      // Trace a cnode.
+      auto cnode = node->cast<CNodePtr>();
+      if (cnode) {
+        return TraceEffectInfo(cnode);
+      }
+
+      // Trace a parameter.
+      auto para = node->cast<ParameterPtr>();
+      if (para) {
+        return TraceEffectInfo(para);
+      }
+
+      // Trace a primitive.
+      auto prim = GetPrimitive(node);
+      if (prim) {
+        return GetPrimEffectInfo(prim);
+      }
+
+      // Trace a func graph.
+      auto value_node = node->cast<ValueNodePtr>();
+      if (value_node && value_node->value()) {
+        auto graph = value_node->value()->cast<FuncGraphPtr>();
+        if (graph) {
+          return GetEffectInfo(graph);
+        }
+      }
+    }
+    // Something is wrong if we reach here.
+    MS_LOG(WARNING) << "EffectInfo untraceable: " << node->DebugString(2);
+    return {EffectInfo::kDetected, false, false, false};
+  }
+
+  int GetParameterIndex(const FuncGraphPtr &func_graph, const ParameterPtr &para) {
+    int index = 0;
+    for (auto &parameter : func_graph->parameters()) {
+      if (para == parameter) {
+        return index;
+      }
+      ++index;
+    }
+    MS_LOG(EXCEPTION) << "Parameter not found: " << (para ? para->DebugString() : "");
+  }
+
+  // Trace effect info from a function parameter.
+  EffectInfo TraceEffectInfo(const ParameterPtr &para) {
+    EffectInfo info{EffectInfo::kDetected, false, false, false};
+    ForEachRealArguments(para, [this, &info](const AnfNodePtr &arg) {
+      // Merge the caller input's effect info.
+      auto input_info = TraceEffectInfo(arg);
+      info.Merge(input_info);
+    });
+    return info;
+  }
+
+  void ForEachRealArguments(const ParameterPtr &para, std::function<void(const AnfNodePtr &)> handler) {
+    auto func_graph = para->func_graph();
+    MS_EXCEPTION_IF_NULL(func_graph);
+    // Find the index of the parameter, starting from 0.
+    const int para_index = GetParameterIndex(func_graph, para);
+    const size_t input_index = static_cast<size_t>(para_index) + 1;
+    // Search user cnodes of the func graph.
+    auto &users = func_graph->func_graph_cnodes_index();
+    if (users.empty()) {
+      MS_LOG(WARNING) << "Unused graph for parameter " << para->DebugString();
+    }
+    for (auto &user : users) {
+      auto use_index = user.first->second;
+      if (use_index != 0) {
+        // Skip non-caller usage.
+        continue;
+      }
+      // The caller cnode.
+      auto cnode = dyn_cast<CNode>(user.first->first);
+      if (cnode && input_index < cnode->size()) {
+        handler(cnode->input(input_index));
+      }
+    }
+  }
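+
+// Illustrative example: for a higher-order graph such as
+//   def hof(fn): return fn(x)
+// called as hof(assign_fn), the real argument of parameter fn is assign_fn,
+// so TraceEffectInfo(fn) merges the effect info of assign_fn.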
+
+  // For a call node, returns the effect info of the callee graph.
+  EffectInfo GetCallEffectInfo(const CNodePtr &cnode) {
+    constexpr size_t min_call_node_size = 2;
+    if (cnode->size() < min_call_node_size) {
+      MS_LOG(EXCEPTION) << "Invalid call node: " << cnode->DebugString();
+    }
+    auto func_graph = GetValueNode<FuncGraphPtr>(cnode->inputs().at(1));
+    if (func_graph == nullptr) {
+      MS_LOG(EXCEPTION) << "Invalid call node: " << cnode->DebugString();
+    }
+    return GetEffectInfo(func_graph);
+  }
+
+  // Detect effect info by depth-first search.
+  EffectInfo DetectEffectInfo(const CNodePtr &cnode) {
+    // For a primitive, get effect info from its attributes and inputs.
+    auto prim = GetPrimitive(cnode);
+    if (prim) {
+      // Skip the 'return' cnode.
+      if (IsPrimitiveEquals(prim, prim::kPrimReturn)) {
+        return {EffectInfo::kDetected, false, false, false};
+      }
+      // Special handling for the 'call' cnode.
+      if (IsPrimitiveEquals(prim, prim::kPrimCall)) {
+        return GetCallEffectInfo(cnode);
+      }
+      auto info = GetPrimEffectInfo(prim);
+      if (!info.memory && !IsKeepRef(prim)) {
+        // For a primitive call with no memory effect but with a Ref
+        // parameter used, we will insert a 'Load' before it,
+        // except for primitives like J(f) or Partial(f, x) which propagate
+        // side effects; for those, the Load is inserted inside func_graph f.
+        info.load = HasRefInput(cnode);
+      }
+      return info;
+    }
+
+    // For a func graph, detect effect info from its child cnodes.
+    auto func_graph = GetFuncGraph(cnode);
+    if (func_graph) {
+      return GetEffectInfo(func_graph);
+    }
+
+    // When input[0] is a cnode, it is a function returned from
+    // a high-order function call; we trace it by return value.
+    auto func_cnode = GetFuncCNode(cnode);
+    if (func_cnode) {
+      caller_ = cnode;
+      return TraceEffectInfo(func_cnode);
+    }
+
+    // When input[0] is a parameter, it is a function parameter of
+    // a high-order function; we trace it by the caller.
+    auto func_para = GetFuncParameter(cnode);
+    if (func_para) {
+      return TraceEffectInfo(func_para);
+    }
+
+    // When input[0] is a MultitypeFuncGraph, it is not specialized
+    // because one of its parameters is AbstractUndetermined.
+    // This MultitypeFuncGraph may be specialized in the next Renormalize
+    // process, but we have to keep the order by inserting UMonad now,
+    // otherwise the order will be lost in the next Renormalize.
+    // So we conservatively assume it has a memory side effect.
+    auto func_multitype = GetFuncMultitypeFuncGraph(cnode);
+    if (func_multitype) {
+      MS_LOG(DEBUG) << "Assume memory side effect for: " << cnode->DebugString();
+      return {EffectInfo::kDetected, true, false, false};
+    }
+
+    MS_LOG(WARNING) << "Side effect undetectable: " << cnode->DebugString(2);
+    return {EffectInfo::kDetected, false, false, false};
+  }
+
+  // Gets EffectInfo for a CNode.
+  EffectInfo GetEffectInfo(const CNodePtr &cnode) {
+    const auto &effect_info = cnode->GetEffectInfo();
+    if (effect_info.state == EffectInfo::kDetected) {
+      // Effect info is already detected, return it.
+      return effect_info;
+    }
+
+    // Detect effect info for the cnode.
+    EffectInfo info = DetectEffectInfo(cnode);
+    if (info.state == EffectInfo::kDetected) {
+      // Save the detected info into the cnode.
+      cnode->SetEffectInfo(info);
+    }
+    return info;
+  }
+
+  // Gets the SCC that the given graph belongs to.
+  const SccPtr &GetScc(const FuncGraphPtr &func_graph) const {
+    auto found = scc_map_.find(func_graph);
+    if (found == scc_map_.end()) {
+      MS_LOG(EXCEPTION) << "SCC not found for " << func_graph->ToString();
+    }
+    return found->second;
+  }
+
+  // Set effect info for all member graphs in the SCC.
+  void SetSccEffectInfo(const SccPtr &scc, const EffectInfo &info) {
+    for (auto &g : *scc) {
+      g->SetEffectInfo(info);
+    }
+  }
+
+  // Gets EffectInfo for a func graph.
+  EffectInfo GetEffectInfo(const FuncGraphPtr &func_graph) {
+    const auto &effect_info = func_graph->GetEffectInfo();
+    if (effect_info.state != EffectInfo::kUnknown) {
+      // Effect info is already set, return it.
+      return effect_info;
+    }
+    // Get the SCC that this graph belongs to.
+    auto &scc = GetScc(func_graph);
+    // To prevent SCC members from being visited again, we set the effect info
+    // to the 'kDetecting' state before starting to check cnodes.
+    EffectInfo info{EffectInfo::kDetecting, false, false, false};
+    SetSccEffectInfo(scc, info);
+    // Check side effects for all cnodes in the SCC.
+    std::vector<CNodePtr> undetected;
+    for (auto &g : *scc) {
+      for (auto &cnode : g->order_list()) {
+        auto cnode_effect = GetEffectInfo(cnode);
+        if (cnode_effect.state != EffectInfo::kDetected) {
+          // A node with undetected side effects could be a call to an SCC member graph;
+          // we will check its side effects again after the SCC side effects are detected.
+          undetected.push_back(cnode);
+        }
+        // Merge effect info from the node.
+        info.Merge(cnode_effect);
+      }
+      // Make sure all sub-graphs are checked, since some sub-graphs may not be directly called,
+      // for example: return ValueNode(sub_graph).
+      for (auto &sg : g->func_graphs_used()) {
+        (void)GetEffectInfo(sg.first);
+      }
+    }
+    // Update effect info for all members of the SCC.
+    info.state = EffectInfo::kDetected;
+    SetSccEffectInfo(scc, info);
+    // Check undetected cnodes again after the side effects of the SCC are detected.
+    for (auto &cnode : undetected) {
+      auto cnode_effect = GetEffectInfo(cnode);
+      // Side effects should be detected now.
+      if (cnode_effect.state != EffectInfo::kDetected) {
+        MS_LOG(EXCEPTION) << "Side effect is undetectable: " << cnode->DebugString();
+      }
+    }
+    // A graph that needs PipelineSplit is treated as having no effects.
+    if (func_graph->stage() != -1) {
+      info.memory = false;
+      info.load = false;
+      info.io = false;
+    }
+    return info;
+  }
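+
+// Illustrative example: for a self-recursive graph f that also calls Print,
+// the SCC {f} is first marked kDetecting; the recursive call site reports an
+// undetected state and is queued, then re-checked after the merge, by which
+// time the final info (io = true) is recorded for the whole SCC.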
+
+  void SaveBranchCaller(const CNodePtr &switch_node, const FuncGraphPtr &branch) {
+    auto manager = branch->manager();
+    MS_EXCEPTION_IF_NULL(manager);
+    auto &node_users = manager->node_users();
+    auto found = node_users.find(switch_node);
+    if (found == node_users.end()) {
+      MS_LOG(WARNING) << "Caller not found for " << switch_node->DebugString();
+      return;
+    }
+    if (found->second.size() != 1) {
+      MS_LOG(WARNING) << "Wrong callers " << found->second.size() << " for " << switch_node->DebugString();
+      return;
+    }
+    auto &user = *found->second.begin();
+    auto cnode = dyn_cast<CNode>(user.first);
+    if (cnode != nullptr && user.second == 0) {
+      branch_caller_map.emplace(branch, cnode);
+    }
+  }
+
+  void UpdateBranchCaller(const FuncGraphPtr &branch) {
+    auto iter = branch_caller_map.find(branch);
+    if (iter == branch_caller_map.end()) {
+      return;
+    }
+    const auto &caller = iter->second;
+    const auto &info = branch->GetEffectInfo();
+    AddMonadForCaller(caller, info);
+  }
+
+  void AddMonadForCaller(const CNodePtr &caller, const EffectInfo &info) {
+    if (info.memory || info.load) {
+      // Add a u monad argument to the caller if needed.
+      AddMonadArgument(caller, kUMonad);
+    }
+    if (info.io) {
+      // Add an io monad argument to the caller if needed.
+      AddMonadArgument(caller, kIOMonad);
+    }
+  }
+
+  void AddMonadArgument(const CNodePtr &cnode, const ValuePtr &monad) {
+    auto monad_abs = monad->ToAbstract();
+    for (size_t i = 1; i < cnode->size(); ++i) {
+      auto abs = cnode->inputs().at(i)->abstract();
+      if (abs != nullptr && *abs == *monad_abs) {
+        // Skip if the monad argument already exists.
+        return;
+      }
+    }
+    // Add the monad argument if not yet present.
+    auto monad_input = NewValueNode(monad);
+    monad_input->set_abstract(monad_abs);
+    if ((monad == kUMonad) && cnode->size() > 1 && HasAbstractIOMonad(cnode->inputs().back())) {
+      // Insert the u monad before the io monad.
+      size_t last_index = cnode->size() - 1;
+      cnode->add_input(cnode->input(last_index));
+      cnode->set_input(last_index, monad_input);
+    } else {
+      // Add the monad as the last input.
+      cnode->add_input(monad_input);
+    }
+  }
+
+ private:
+  // The root graph.
+  FuncGraphPtr root_;
+
+  // SCC map.
+  SccMap scc_map_;
+
+  // Single branch (in switch) and its caller cnode.
+  std::map<FuncGraphPtr, CNodePtr> branch_caller_map;
+
+  // Current high-order function caller cnode.
+  CNodePtr caller_ = nullptr;
+
+  // switch_layer_calls saves all switch_layer calls, so that we can check
+  // whether a monad argument should be added for them.
+  std::vector<SwitchLayerCall> switch_layer_calls;
+};  // class SideEffectFinder
+
+// --------------------------------------------------------------------
+// AutoMonadConverter converts side-effect cnodes into monad form.
+// --------------------------------------------------------------------
+class AutoMonadConverter {
+ public:
+  static bool Handle(const FuncGraphPtr &func_graph, bool top) {
+    AutoMonadConverter converter(func_graph, top);
+    return converter.Run();
+  }
+
+ private:
+  AutoMonadConverter(const FuncGraphPtr &func_graph, bool top)
+      : func_graph_(func_graph), manager_(func_graph->manager()), top_(top) {}
+
+  ~AutoMonadConverter() = default;
+
+  bool Run() {
+    // Handle cnodes if the graph has side effects.
+    if (HasSideEffects()) {
+      HandleCNodes();
+    }
+    // Clean up after the conversion is finished.
+    func_graph_->ClearIsolateNodes();
+    func_graph_->ClearOrderList();
+    return has_effect_cnodes_;
+  }
+
+  // Check whether there are side effects according to the effect info.
+  static bool HasSideEffects(const EffectInfo &info) { return (info.memory || info.io || info.load); }
+
+  // Check whether the current graph has side effects.
+  bool HasSideEffects() const {
+    const auto &info = func_graph_->GetEffectInfo();
+    if (info.state != EffectInfo::kDetected) {
+      // Effect info should have been set by SideEffectFinder, except for unused graphs.
+      MS_LOG(INFO) << "No effect info for unused graph: " << func_graph_->ToString();
+      return false;
+    }
+    return HasSideEffects(info);
+  }
+
+  // Gets the effect info for a cnode.
+  const EffectInfo &GetEffectInfo(const CNodePtr &cnode) {
+    auto &effect_info = cnode->GetEffectInfo();
+    if (effect_info.state != EffectInfo::kDetected) {
+      // Effect info should have been set by SideEffectFinder.
+      MS_LOG(EXCEPTION) << "Side effects not detected: " << cnode->DebugString();
+    }
+    return effect_info;
+  }
+
+  //
+  // Handle CNodes for side effects.
+  //
+  void HandleCNodes() {
+    // Check whether UpdateState and Depend are required.
+    bool update_state = NeedUpdateState();
+
+    // Check all cnodes in the order list.
+    for (auto &cnode : func_graph_->order_list()) {
+      auto &info = GetEffectInfo(cnode);
+      has_effect_cnodes_ = (has_effect_cnodes_ || HasSideEffects(info));
+      if (cnode->func_graph() != func_graph_) {
+        // Handle an outer cnode.
+        HandleOuterNode(cnode, info);
+      } else {
+        // Handle a cnode with memory side effects.
+        if (info.memory) {
+          HandleMemoryEffects(cnode, update_state);
+        } else if (info.load) {
+          // If there are no memory side effects, handle a load if needed.
+          HandleLoad(cnode, update_state);
+        }
+        // Handle a cnode with IO side effects.
+        if (info.io) {
+          HandleIoEffects(cnode, update_state);
+        }
+      }
+      cnode->SetEffectHandled(true);
+    }
+    // Insert Depend nodes for the states if required.
+    if (update_state) {
+      InsertStateDepends();
+    }
+  }
+
+  void HandleOuterNode(const CNodePtr &cnode, const EffectInfo &info) {
+    if (info.memory || info.load) {
+      (void)GetUniverse();
+      bool load_with_primitive = (info.load && IsPrimitiveCNode(cnode));
+      if (!cnode->IsEffectHandled() && !load_with_primitive) {
+        auto u = NewValueNode(kUMonad);
+        u->set_abstract(kUMonad->ToAbstract());
+        cnode->add_input(u);
+      }
+    }
+    if (info.io) {
+      (void)GetIoState();
+      if (!cnode->IsEffectHandled()) {
+        auto io = NewValueNode(kIOMonad);
+        io->set_abstract(kIOMonad->ToAbstract());
+        cnode->add_input(io);
+      }
+    }
+  }
+
+  //
+  // Convert a cnode with memory side effects to monad form,
+  // from:
+  //   output = func(input)
+  // to:
+  //   output = func(input, u)
+  //   u = UpdateState(u, output)  # if update_state is true
+  //
+  void HandleMemoryEffects(const CNodePtr &cnode, bool update_state) {
+    const auto &u = GetUniverse();
+    AddMonadInput(cnode, u);
+    if (update_state) {
+      u_ = UpdateState(u, cnode);
+    }
+  }
+
+  //
+  // Convert a cnode with IO side effects to monad form,
+  // from:
+  //   output = func(input)
+  // to:
+  //   output = func(input, io)
+  //   io = UpdateState(io, output)  # if update_state is true
+  //
+  void HandleIoEffects(const CNodePtr &cnode, bool update_state) {
+    const auto &io = GetIoState();
+    AddMonadInput(cnode, io);
+    if (update_state) {
+      io_ = UpdateState(io, cnode);
+    }
+  }
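+
+// Illustrative example: two consecutive side-effect calls become a chained
+// monad sequence, so their original order is preserved by data dependencies:
+//   %0 = Assign(%para1, %x, u)
+//   u1 = UpdateState(u, %0)
+//   %1 = Assign(%para2, %y, u1)
+//   u2 = UpdateState(u1, %1)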
+ abstract::AbstractBasePtrList load_abstracts; + std::transform(loads.begin(), loads.end(), std::back_inserter(load_abstracts), + [](const AnfNodePtr &load) { return load->abstract(); }); + loads.insert(loads.begin(), NewValueNode(prim::kPrimMakeTuple)); + auto make_tuple = func_graph_->NewCNode(loads); + make_tuple->set_abstract(std::make_shared(load_abstracts)); + u_ = UpdateState(u, make_tuple); + } + + std::vector MakeLoads(const CNodePtr &cnode, const RefInputs &ref_inputs, const AnfNodePtr &u) { + std::vector loads; + for (auto &ref_input : ref_inputs) { + // Make a Load cnode for ref input. + auto &ref = ref_input.first; + auto load = MakeLoad(cnode, ref, u); + // Replace input with the load cnode. + for (size_t index : ref_input.second) { + manager_->SetEdge(cnode, index, load); + } + loads.emplace_back(std::move(load)); + } + return loads; + } + + CNodePtr MakeLoad(const CNodePtr &cnode, const AnfNodePtr &ref, const AnfNodePtr &u) { + static const std::string primitive_target = "primitive_target"; + // Create Load cnode. + auto load_prim = NewValueNode(prim::kPrimLoad); + auto load_cnode = func_graph_->NewCNode({load_prim, ref, u}); + // Set device target for Load CNode. + std::string target = GetCNodeTarget(cnode); + load_cnode->set_user_data(primitive_target, std::make_shared(target)); + // Set load_cnode abstract to Tensor according the input Ref[Tensor]. + auto ref_abs = dyn_cast(ref->abstract()); + MS_EXCEPTION_IF_NULL(ref_abs); + load_cnode->set_abstract(ref_abs->CloneAsTensor()); + return load_cnode; + } + + // Add or replace monad input. + void AddMonadInput(const CNodePtr &cnode, const AnfNodePtr &monad) { + constexpr size_t max_monad_inputs = 2; + auto monad_abs = monad->abstract(); + auto &inputs = cnode->inputs(); + int last = static_cast(inputs.size()) - 1; + int stop = last - max_monad_inputs; + // Search monad in inputs, replace it if found. + for (int i = last; i > 0 && i > stop; --i) { + size_t index = static_cast(i); + auto input_abs = inputs[index]->abstract(); + if (input_abs && *input_abs == *monad_abs) { + manager_->SetEdge(cnode, i, monad); + return; + } + } + // If monad not found in inputs, add a monad input. + manager_->AddEdge(cnode, monad); + } + + void InsertStateDepends() { + if (u_) { + // Insert Depend node for UMonad, + // Gradient is required for memory side effects. + InsertStateDepend(u_); + } + if (io_) { + // No gradient required for IO operations. + InsertStateDepend(io_); + } + } + + void InsertStateDepend(const AnfNodePtr &state) { + // Insert Depend node and set it as output. 
+
+  void InsertStateDepend(const AnfNodePtr &state) {
+    // Insert a Depend node and set it as the output.
+    auto depend = NewValueNode(prim::kPrimDepend);
+    auto output = GetGraphOutput();
+    auto depend_cnode = func_graph_->NewCNode({depend, output, state});
+    depend_cnode->set_abstract(output->abstract());
+    func_graph_->set_output(depend_cnode);
+  }
+
+  AnfNodePtr GetGraphOutput() {
+    auto output = func_graph_->output();
+    if (output != nullptr) {
+      return output;
+    }
+    return NewValueNode(kNone);
+  }
+
+  AnfNodePtr UpdateState(const AnfNodePtr &state, const AnfNodePtr &attach) {
+    auto update_state = NewValueNode(prim::kPrimUpdateState);
+    auto update_state_cnode = func_graph_->NewCNode({update_state, state, attach});
+    update_state_cnode->set_abstract(state->abstract());
+    return update_state_cnode;
+  }
+
+  AnfNodePtr &GetUniverse() {
+    if (u_ == nullptr) {
+      if (top_) {
+        u_ = NewValueNode(kUMonad);
+        u_->set_abstract(kUMonad->ToAbstract());
+      } else {
+        u_ = AddMonadParameter(func_graph_, "u", kUMonad->ToAbstract());
+      }
+    }
+    return u_;
+  }
+
+  AnfNodePtr &GetIoState() {
+    if (io_ == nullptr) {
+      if (top_) {
+        io_ = NewValueNode(kIOMonad);
+        io_->set_abstract(kIOMonad->ToAbstract());
+      } else {
+        io_ = AddMonadParameter(func_graph_, "io", kIOMonad->ToAbstract());
+      }
+    }
+    return io_;
+  }
+
+  // Return true if update_state should be used in this func graph.
+  // In some cases, update_state can be omitted, such as:
+  //   def side_effect_tail_call(args):
+  //     a = pure_func(args)
+  //     return side_effect_call(a)
+  bool NeedUpdateState() {
+    // Search for the single side-effect cnode, if any.
+    CNodePtr side_effect_cnode = nullptr;
+    for (auto &cnode : func_graph_->order_list()) {
+      if (HasSideEffect(cnode)) {
+        if (side_effect_cnode != nullptr) {
+          // There are multiple side-effect cnodes, so update_state is required.
+          return true;
+        }
+        side_effect_cnode = cnode;
+      }
+    }
+    if (side_effect_cnode == nullptr) {
+      // No side-effect cnode, no update_state.
+      return false;
+    }
+    if (IsPrimitiveCNode(side_effect_cnode)) {
+      // Always add update_state for a primitive cnode.
+      return true;
+    }
+    // If the only side-effect cnode is not the tail call, update_state is required.
+    return func_graph_->output() != side_effect_cnode;
+  }
+
+  bool HasSideEffect(const CNodePtr &cnode) {
+    const auto &info = GetEffectInfo(cnode);
+    return (info.memory || info.load || info.io);
+  }
+
+ private:
+  // The func graph to be converted.
+  const FuncGraphPtr &func_graph_;
+
+  // The func graph manager, used for graph edge updates.
+  FuncGraphManagerPtr manager_;
+
+  // True if converting the top graph.
+  const bool top_;
+
+  // True if there are side-effect cnodes within this func graph.
+  bool has_effect_cnodes_ = false;
+
+  // Current memory state node, null if there are no memory side effects.
+  AnfNodePtr u_;
+
+  // Current IO state node, null if there are no IO side effects.
+  AnfNodePtr io_;
+};  // class AutoMonadConverter
+
+}  // namespace
+
+// Entry point of the auto-monad phase;
+// the func_graph should be resolved and inference should be done.
+// Returns true if side-effect nodes are found in func_graph.
+bool AutoMonad(const FuncGraphPtr &func_graph) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(func_graph->manager());
+
+  // Search and mark side effects for the graph and its sub-graphs.
+  // This should be called before auto-monad starts.
+  SideEffectFinder::Search(func_graph);
+
+  // Execute the auto-monad conversion on the top graph.
+  bool has_effects = AutoMonadConverter::Handle(func_graph, true);
+  // Convert used sub-graphs.
+  auto fg_used_total = func_graph->func_graphs_used_total();
+  for (auto &fg : fg_used_total) {
+    auto top_flag = fg->has_flag(mindspore::kFuncGraphFlagBackPropEntry);
+    if (fg->stage() != -1) {
+      top_flag = true;
+    }
+    bool fg_has_effects = AutoMonadConverter::Handle(fg, top_flag);
+    has_effects = has_effects || fg_has_effects;
+  }
+
+  // Clear isolate nodes after auto-monad has finished.
+  auto manager = func_graph->manager();
+  if (manager) {
+    manager->ClearIsolateNodes();
+  }
+  return has_effects;
+}
+
+bool ReAutoMonad(const FuncGraphPtr &func_graph) {
+  // AutoMonad for the bprop network: only add monads for func graphs whose back
+  // propagators have side effects, or for a MultitypeFuncGraph that is specialized
+  // in Renormalize rather than in the first Specialize pass.
+  bool need_auto_monad = false;
+  std::vector<FuncGraphPtr> auto_monaded_fg;
+  func_graph->EraseUnusedNodeInOrder();
+  for (auto &fg : func_graph->func_graphs_used_total()) {
+    if (fg->has_flag(mindspore::kFuncGraphFlagReAutoMonad)) {
+      auto_monaded_fg.push_back(fg);
+      for (auto &used_fg : fg->func_graphs_used_total()) {
+        used_fg->set_flag(mindspore::kFuncGraphFlagReAutoMonad, true);
+        auto_monaded_fg.push_back(used_fg);
+      }
+      need_auto_monad = true;
+      MS_LOG(DEBUG) << "AutoMonad Grad for func graph: " << fg->ToString();
+    }
+    fg->EraseUnusedNodeInOrder();
+  }
+  bool changed = false;
+  if (need_auto_monad) {
+    for (auto &fg : func_graph->func_graphs_used_total()) {
+      if (!fg->has_flag(mindspore::kFuncGraphFlagReAutoMonad)) {
+        fg->ClearOrderList();
+        fg->ClearIsolateNodes();
+      }
+    }
+    changed = AutoMonad(func_graph);
+    for (auto &fg : auto_monaded_fg) {
+      fg->erase_flag(mindspore::kFuncGraphFlagReAutoMonad);
+    }
+    // After auto-monad, the order lists and isolate nodes in the graphs and the manager will be cleared.
+  } else {
+    func_graph->ClearOrderList();
+    func_graph->ClearIsolateNodes();
+    for (auto &fg : func_graph->func_graphs_used_total()) {
+      fg->ClearOrderList();
+      fg->ClearIsolateNodes();
+    }
+    MS_EXCEPTION_IF_NULL(func_graph->manager());
+    func_graph->manager()->ClearIsolateNodes();
+  }
+  return changed;
+}
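+
+// Illustrative call sites (see pipeline/jit/pass.cc in this patch): AutoMonad runs
+// once from the frontend pipeline after type inference, while ReAutoMonad is
+// invoked through the "auto_monad_grad" pass (ReAutoMonadWrapper) after grad
+// and Renormalize.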
+}  // namespace mindspore::pipeline
diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/auto_monad.h b/mindspore/ccsrc/pipeline/jit/static_analysis/auto_monad.h
new file mode 100644
index 0000000000..238e3665be
--- /dev/null
+++ b/mindspore/ccsrc/pipeline/jit/static_analysis/auto_monad.h
@@ -0,0 +1,36 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_AUTO_MONAD_H_
+#define MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_AUTO_MONAD_H_
+
+#include <memory>
+#include <string>
+
+#include "ir/anf.h"
+#include "ir/func_graph.h"
+#include "base/effect_info.h"
+
+namespace mindspore::pipeline {
+
+// Run auto-monad to handle side effects; called from the frontend pipeline.
+bool AutoMonad(const FuncGraphPtr &func_graph);
+
+// Run auto-monad after grad or Renormalize to handle side effects; called from the frontend opt pass.
+bool ReAutoMonad(const FuncGraphPtr &func_graph);
+}  // namespace mindspore::pipeline
+
+#endif  // MINDSPORE_CCSRC_PIPELINE_JIT_PARSE_AUTO_MONAD_H_
diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc
index b290f94bad..2ab2520450 100644
--- a/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc
+++ b/mindspore/ccsrc/pipeline/jit/static_analysis/evaluator.cc
@@ -97,29 +97,53 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
                       << MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_MAX_CALL_DEPTH)
                       << ", please call 'context.set_context(max_call_depth=value)' to adjust this value.";
   }
+  // Analyze isolate nodes first, as some validation checks in FuncGraph are isolate nodes.
+  for (const auto &node : fg->GetIsolateNodesInOrder()) {
+    AnfNodeConfigPtr node_conf = engine->MakeConfig(node, graph_context_);
+    MS_LOG(DEBUG) << "Analysis isolate_node begin, func graph: " << fg.get() << fg->ToString()
+                  << ", node_conf: " << node_conf->ToString();
+    auto isolate_base = engine->GetEvaluatedValue(node_conf)->abstract();
+    MS_LOG(DEBUG) << "Analysis isolate_node end, func graph: " << fg.get() << fg->ToString()
+                  << ", node_conf: " << node_conf->ToString() << ", abstract: " << isolate_base->ToString();
+  }
+
   const auto &all_nodes = TopoSort(func_node, SuccIncoming, [&fg](const AnfNodePtr &node) -> IncludeType {
     if (node->func_graph() != fg || node->isa<ValueNode>()) {
       return EXCLUDE;
     }
    return FOLLOW;
   });
+  bool isolate_node_propagate_flag = false;
   for (const auto &node : all_nodes) {
     AnfNodeConfigPtr node_conf = engine->MakeConfig(node, graph_context_);
     MS_LOG(DEBUG) << "Analysis node begin, func graph: " << fg.get() << fg->ToString()
                   << ", node_conf: " << node_conf->ToString();
-    ret_base = engine->GetEvaluatedValue(node_conf)->abstract();
+    auto node_eval_result = engine->GetEvaluatedValue(node_conf);
+    ret_base = node_eval_result->abstract();
     MS_LOG(DEBUG) << "Analysis node end, func graph: " << fg.get() << fg->ToString()
                   << ", node_conf: " << node_conf->ToString() << ", abstract: " << ret_base->ToString();
+    if (node->isa<CNode>()) {
+      isolate_node_propagate_flag |= node_eval_result->HasIsolateNodesPropagateCNodeFlag();
+      MS_LOG(DEBUG) << "Check isolate_nodes flag for node: " << node->DebugString()
+                    << ", abstract: " << ret_base->ToString()
+                    << ", flag: " << node_eval_result->HasIsolateNodesPropagateCNodeFlag();
+    }
   }
   engine->DecreaseFunctionCallDepth();
   MS_EXCEPTION_IF_NULL(ret_base);
   MS_LOG(DEBUG) << "BaseFuncGraph " << fg->ToString() << " eval end, evaluated abstract: " << ret_base->ToString()
                 << ", is stub: " << fg->stub();
+
   if (fg->stub()) {
-    return std::make_shared<EvalResult>(std::make_shared<AbstractUndetermined>(), nullptr);
+    ret_base = std::make_shared<AbstractUndetermined>();
+  }
+  auto eval_result = std::make_shared<EvalResult>(ret_base, std::make_shared<AttrValueMap>());
+  if (isolate_node_propagate_flag) {
+    eval_result->SetIsolateNodesPropagateCNodeFlag(true);
+    eval_result->SetIsolateNodesPropagateFuncGraphFlag(true);
   }
-  return std::make_shared<EvalResult>(ret_base, nullptr);
+  return eval_result;
 }
 
 AbstractBasePtrList FuncGraphEvaluator::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const {
diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc
index 830e94ee90..aeff8804af 100644
--- a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc
+++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc
@@ -46,8 +46,8 @@ namespace mindspore {
 namespace abstract {
 using mindspore::parse::PyObjectWrapper;
 
-std::unordered_set<std::string> prims_to_skip_undetermined_infer{"make_tuple", "make_list", "switch", "env_setitem",
-                                                                 "env_getitem"};
+std::unordered_set<std::string> prims_to_skip_undetermined_infer{
+  "make_tuple", "make_list", "switch", "env_setitem", "env_getitem", "Load", "UpdateState"};
 
 EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
                                         AnfNodeConfigPtr out_conf) {
@@ -187,7 +187,7 @@ AnfNodePtr MixedPrecisionCastHelper(const AnfNodePtr &source_node, const Abstrac
     if (x->element()->BuildType()->isa<Float>()) {
       auto cast = prim::GetPythonOps("cast", "mindspore.ops.functional");
       MS_EXCEPTION_IF_NULL(cast);
-      target_node = func_graph->NewCNode({NewValueNode(cast), source_node, target_type});
+      target_node = func_graph->NewCNodeAfter(source_node, {NewValueNode(cast), source_node, target_type});
     }
   } else if (node_type->isa<AbstractTuple>()) {
     auto x = node_type->cast<AbstractTuplePtr>();
@@ -442,6 +442,10 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) {
     dic[ATTR_SHAPE] = py::none();
     dic[ATTR_DTYPE] = arg->BuildType();
     dic[ATTR_VALUE] = py::none();
+  } else if (abs_base->isa<AbstractMonad>()) {
+    dic[ATTR_SHAPE] = py::none();
+    dic[ATTR_DTYPE] = abs_base->BuildType();
+    dic[ATTR_VALUE] = py::none();
   } else {
     auto value = abs_base->BuildValue();
     if ((*value == *kAnyValue)) {
@@ -472,8 +476,10 @@ py::tuple PreparePyInputs(const PrimitivePyPtr &prim_py, const AbstractBasePtrLi
     args_ptr = &args;
   }
 
-  py::tuple py_args(args_ptr->size());
-  for (size_t i = 0; i < args_ptr->size(); i++) {
+  // Monad parameters are defined at the end of the parameter list and need to be ignored.
+  std::size_t size_args = args_ptr->size() - GetAbstractMonadNum(*args_ptr);
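+  // For example (illustrative): for Assign(ref, value, u), the abstract list is
+  // (ref_abs, value_abs, umonad_abs); GetAbstractMonadNum returns 1, so only the
+  // first two abstracts are converted and passed to the Python infer function.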
+  py::tuple py_args(size_args);
+  for (size_t i = 0; i < size_args; i++) {
     auto arg_i = (*args_ptr)[i];
     py_args[i] = ConvertAbstractToPython(arg_i);
   }
@@ -582,7 +588,8 @@ EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, c
   AbstractBasePtr abs_base = eval_impl_(engine, prim_, args);
   prim_->EndRecordAddAttr();
   auto added_attrs = prim_->evaluate_added_attrs();
-  return std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
+  auto eval_result = std::make_shared<EvalResult>(abs_base, std::make_shared<AttrValueMap>(added_attrs));
+  return eval_result;
 }
 
 EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const AbstractBasePtrList &args) {
@@ -800,13 +807,17 @@ EvalResultPtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &engin
   // item_name to func addr from obj_map
   parse::SymbolPtr symbol = item_v->cast<parse::SymbolPtr>();
   parse::NameSpacePtr name_space = data_v->cast<parse::NameSpacePtr>();
-  FuncGraphPtr func_graph = out_conf->node()->func_graph();
+  auto out_node = out_conf->node();
+  FuncGraphPtr func_graph = out_node->func_graph();
 
-  auto new_node = parse::ResolveSymbol(func_graph->manager(), name_space, symbol, out_conf->node());
+  auto new_node = parse::ResolveSymbol(func_graph->manager(), name_space, symbol, out_node);
   if (new_node == nullptr) {
     MS_LOG(EXCEPTION) << "Resolve node failed";
   }
+
+  // Replace the old node with the resolved new node in the order list.
+  func_graph->ReplaceInOrder(out_node, new_node);
+
   AnalysisEnginePtr eng = out_conf->engine();
   AnfNodeConfigPtr fn_conf = eng->MakeConfig(new_node, out_conf->context());
   return eng->ForwardConfig(out_conf, fn_conf);
@@ -1114,7 +1125,7 @@ class CreateInstanceEvaluator : public TransitionPrimEvaluator {
     }
 
     AbstractBasePtr ret = ToAbstract(converted_ret, AnalysisContext::DummyContext(), out_conf);
-    auto infer_result = std::make_shared<EvalResult>(ret, nullptr);
+    auto infer_result = std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
     (*cache_)[args_spec_list] = infer_result;
     return infer_result;
   }
@@ -1171,9 +1182,11 @@ class PartialEvaluator : public Evaluator {
       }
     }
 
-    (void)std::transform(
-      args_conf_list.begin() + 1, args_conf_list.end(), std::back_inserter(args_spec_list),
-      [](const ConfigPtr &config) -> AbstractBasePtr { return config->GetEvaluatedValue()->abstract(); });
+    std::vector<EvalResultPtr> eval_result_list;
+    (void)std::transform(args_conf_list.cbegin() + 1, args_conf_list.cend(), std::back_inserter(eval_result_list),
+                         [](const ConfigPtr &config) -> EvalResultPtr { return config->GetEvaluatedValue(); });
+    (void)std::transform(eval_result_list.cbegin(), eval_result_list.cend(), std::back_inserter(args_spec_list),
+                         [](const EvalResultPtr &eval_result) -> AbstractBasePtr { return eval_result->abstract(); });
     AbstractBasePtrList args(args_spec_list.begin() + 1, args_spec_list.end());
 
     auto cnode = out_conf->node()->cast<CNodePtr>();
@@ -1183,17 +1196,25 @@ class PartialEvaluator : public Evaluator {
                         << ", args_conf_list: " << mindspore::ToString(args_conf_list);
     }
+    auto flag = std::any_of(eval_result_list.cbegin(), eval_result_list.cend(), [](const EvalResultPtr &eval_result) {
+      MS_LOG(DEBUG) << "Propagate isolate nodes flag from: " << eval_result->abstract()->ToString()
+                    << ", flag: " << eval_result->HasIsolateNodesPropagateCNodeFlag();
+      return eval_result->HasIsolateNodesPropagateCNodeFlag();
+    });
     AbstractFuncAtomPtrList partial_funcs_list;
-    auto build_partial = [args, cnode, &partial_funcs_list](const AbstractFuncAtomPtr &atom_func) {
+    auto build_partial = [args, cnode, flag, &partial_funcs_list](const AbstractFuncAtomPtr &atom_func) {
       auto new_func = std::make_shared<PartialAbstractClosure>(atom_func, args, cnode);
       partial_funcs_list.push_back(new_func);
+      if (atom_func->HasIsolateNodesFlag() || flag) {
+        new_func->SetIsolateNodesFlag(true);
+      }
     };
     func->Visit(build_partial);
 
     auto ret = AbstractFunction::MakeAbstractFunction(partial_funcs_list);
-    auto infer_result = std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
-    (*cache_)[args_spec_list] = infer_result;
-    return infer_result;
+    auto eval_result = std::make_shared<EvalResult>(ret, std::make_shared<AttrValueMap>());
+    (*cache_)[args_spec_list] = eval_result;
+    return eval_result;
   }
 
   EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc
index 5642c39c7a..4b0de89e7c 100644
--- a/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc
+++ b/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc
@@ -101,9 +101,6 @@ FuncGraphSpecializer::FuncGraphSpecializer(ProgramSpecializer *const s, const Fu
   cloner_ = SpecializerClone(fg, std::make_shared<TraceSpecialize>(GetNextCounter()));
   repl_node_ = cloner_->cloned_node();
   specialized_func_graph_ = cloner_->cloned_func_graph()[fg];
-  todo_.push_back(fg->get_return());
-  auto ps = fg->parameters();
-  (void)todo_.insert(todo_.end(), ps.begin(), ps.end());
 }
 
 AnfNodePtr FuncGraphSpecializer::ReplicateDisconnectedNode(const AnfNodePtr &node) {
FuncGraphSpecializer::ReplicateDisconnectedNode(const AnfNodePtr &node) { @@ -131,12 +128,24 @@ AnfNodePtr FuncGraphSpecializer::ReplicateDisconnectedNode(const AnfNodePtr &nod } auto c_node = node->cast(); MS_EXCEPTION_IF_NULL(c_node); - auto inputs = c_node->inputs(); - std::vector new_inputs; - (void)std::transform(inputs.begin(), inputs.end(), std::back_inserter(new_inputs), - [this](const AnfNodePtr &inp) -> AnfNodePtr { return ReplicateDisconnectedNode(inp); }); auto c_new_node = new_node->cast(); MS_EXCEPTION_IF_NULL(c_new_node); + auto inputs = c_node->inputs(); + std::vector new_inputs; + (void)std::transform( + inputs.begin(), inputs.end(), std::back_inserter(new_inputs), [this](const AnfNodePtr &inp) -> AnfNodePtr { + auto new_inp = ReplicateDisconnectedNode(inp); + // Refer to the comments in BuildReplacedNode. + if (inp->isa()) { + auto c_inp = inp->cast(); + MS_EXCEPTION_IF_NULL(c_inp); + auto c_new_inp = new_inp->cast(); + MS_EXCEPTION_IF_NULL(c_new_inp); + MS_LOG(DEBUG) << "Replace inp node: " << inp->ToString() << " in order list, with " << new_inp->ToString(); + c_new_inp->func_graph()->ReplaceInOrder(c_inp, c_new_inp); + } + return new_inp; + }); c_new_node->set_inputs(new_inputs); } @@ -180,7 +189,16 @@ void FuncGraphSpecializer::Run() { } void FuncGraphSpecializer::FirstPass() { - while (todo_.size()) { + // Process parameters first. + for (const auto &node : func_graph_->parameters()) { + (void)marked_.insert(node); + ProcessNode(node); + } + ProcessIsolateNodes(); + + todo_.push_back(func_graph_->get_return()); + + while (!todo_.empty()) { AnfNodePtr node = todo_.back(); todo_.pop_back(); if (node->func_graph() == nullptr) { @@ -209,13 +227,25 @@ // Specialize CNode in func graphs void FuncGraphSpecializer::SecondPass() { - for (auto &node : BroadFirstSearchGraphCNodes(specialized_func_graph_->get_return())) { + std::vector starts; + auto &isolate_nodes = specialized_func_graph_->isolate_nodes(); + starts.reserve(isolate_nodes.size() + 1); + starts.push_back(specialized_func_graph_->get_return()); + (void)std::transform(isolate_nodes.begin(), isolate_nodes.end(), std::back_inserter(starts), + [](auto &node) { return dyn_cast(node); }); + for (auto &node : BroadFirstSearchGraphCNodes(starts)) { if (node->isa()) { ProcessCNode(node->cast()); } } } +static AnfNodePtr CreateNoBroadenDepend() { + PrimitivePtr prim = std::make_shared<Primitive>(prim::kPrimDepend->name(), prim::kPrimDepend->attrs()); + prim->set_attr(ATTR_NO_BROADEN, prim::kValueOne); + return BuildValueNode(prim, FromValueInside(prim)); +} + void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); ScopeGuard scope_guard(node->scope()); @@ -248,19 +278,32 @@ for (size_t i = 0; i < old_inputs.size(); ++i) { auto node_input = old_inputs[i]; AnfNodeConfigPtr iconf = MakeConfig(node_input); - AbstractBasePtr ival = GetEvaluatedValueWrap(iconf); + auto eval_result = iconf->GetEvaluatedValue(); + AbstractBasePtr ival = eval_result->abstract(); // First try to check if node_input can be replaced by a ValueNode. If cannot, then try to check if // can be replaced by another CNode from anfnode_config_map, otherwise use the replicated node.
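// Editor's note: the ProcessNode hunk below adds a third outcome to the decision described in the
// comment above: the input folds to a ValueNode, but its evaluation carried isolate
// (side-effect-only) nodes. In that case the folded value is wrapped in a no-broaden Depend so the
// side effects stay reachable. A condensed sketch of that wrapping, reusing the patch's helpers:
static AnfNodePtr WrapValueWithIsolateNodes(const FuncGraphPtr &fg, const AnfNodePtr &value_node,
                                            const AnfNodePtr &isolate_carrier) {
  auto depend_op = CreateNoBroadenDepend();  // a Depend the optimizer must not broaden away
  AnfNodePtr wrapped = fg->NewCNode({depend_op, value_node, isolate_carrier});
  wrapped->set_abstract(value_node->abstract());  // the wrapper keeps the folded value's abstract
  return wrapped;
}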
AnfNodePtr replace_node = BuildPossibleValueNode(iconf->node(), ival, attrs); if (replace_node == nullptr) { - replace_node = BuildReplacedNode(iconf); + replace_node = BuildReplacedNode(iconf).second; MS_EXCEPTION_IF_NULL(replace_node); replace_node->set_abstract(ival); MS_LOG(DEBUG) << "Set replaced: " << replace_node->ToString() << ", to abstract: " << ival->ToString(); + } else if (node_input->isa() && eval_result->HasIsolateNodesPropagateCNodeFlag() && + node_input != node_input->func_graph()->output()) { + // Handle isolate nodes + auto inp_c_node = node_input->cast(); + auto collected = CollectCNodeWithIsolateNodes(inp_c_node, eval_result, c_new->func_graph()); + auto depend_ops = CreateNoBroadenDepend(); + AnfNodePtr new_cnode = specialized_func_graph_->NewCNode({depend_ops, replace_node, collected}); + new_cnode->set_abstract(ival); + replace_node = new_cnode; + MS_LOG(DEBUG) << "Build possible depend node for node: " << node_input->DebugString() + << ", ival: " << ival->ToString() << ", replace_node: " << replace_node->DebugString(); } else { - MS_LOG(DEBUG) << "Build possible value node for node: " << node_input->DebugString() - << ", ival: " << ival->ToString() << ", replace_node: " << replace_node->ToString(); + MS_LOG(DEBUG) << "Did not set replaced node for node: " << node_input->DebugString() + << ", ival: " << ival->ToString() << ", replace_node: " << replace_node->DebugString(); } + if (new_inputs[i] != replace_node) { new_inputs[i] = replace_node; MS_LOG(DEBUG) << "Set new_input[" << i << "] = " << replace_node->DebugString(); @@ -270,19 +313,141 @@ } } -AnfNodePtr FuncGraphSpecializer::BuildReplacedNode(const AnfNodeConfigPtr &conf) { +AnfNodePtr FuncGraphSpecializer::CollectCNodeWithIsolateNodes(const CNodePtr &c_node, + const EvalResultPtr &c_node_eval_result, + const FuncGraphPtr &new_fg) { + auto c_node_inputs = c_node->inputs(); + auto inp0 = c_node_inputs[0]; + auto inp0_conf = MakeConfig(inp0); + auto inp0_eval_result = inp0_conf->GetEvaluatedValue(); + auto inp0_abstract = inp0_eval_result->abstract(); + + auto inp0_abs_func = inp0_abstract->cast(); + if (inp0_abs_func == nullptr) { + MS_LOG_EXCEPTION << "inp0 should be an AbstractFunction, but got: " << inp0_abstract->ToString(); + } + + if (c_node_eval_result->HasIsolateNodesPropagateFuncGraphFlag() || inp0_abs_func->HasIsolateNodesFlag()) { + auto c_node_conf = MakeConfig(c_node); + auto replace_node = BuildReplacedNode(c_node_conf).second; + MS_EXCEPTION_IF_NULL(replace_node); + replace_node->set_abstract(inp0_abstract); + MS_LOG(DEBUG) << "Build possible depend node for node: " << c_node->DebugString() + << ", depend node: " << replace_node->DebugString(); + return replace_node; + } + + // Search inputs from index 1 to find CNodes with IsolateNodes, for inputs that are CNodes and for which a possible ValueNode can be built.
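// Editor's note: the search below gathers, for each qualifying input, one node that carries its
// isolate side effects; multiple carriers are then packed into a single MakeTuple so that one
// Depend input suffices. A condensed sketch of that packing step (mirrors the tail of this
// routine, reusing the patch's helpers):
static AnfNodePtr PackCollectedCarriers(const FuncGraphPtr &fg, std::vector<AnfNodePtr> collected) {
  if (collected.size() == 1) {
    return collected[0];  // a single carrier needs no tuple
  }
  auto make_tuple_op = BuildValueNode(prim::kPrimMakeTuple, FromValueInside(prim::kPrimMakeTuple));
  collected.insert(collected.begin(), make_tuple_op);
  return fg->NewCNode(collected);  // one carrier node for all collected side effects
}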
+ std::vector collected_nodes; + for (std::size_t i = 1; i < c_node_inputs.size(); ++i) { + auto inp_i = c_node_inputs[i]; + if (inp_i->isa()) { + auto inp_i_conf = MakeConfig(inp_i); + auto inp_i_eval_result = inp_i_conf->GetEvaluatedValue(); + auto inp_i_abstract = inp_i_eval_result->abstract(); + if (inp_i_eval_result->HasIsolateNodesPropagateCNodeFlag()) { + static auto attrs = std::make_shared(); + AnfNodePtr replace_node = BuildPossibleValueNode(inp_i, inp_i_abstract, attrs); + if (replace_node == nullptr) { + replace_node = BuildReplacedNode(inp_i_conf).second; + MS_EXCEPTION_IF_NULL(replace_node); + replace_node->set_abstract(inp_i_abstract); + MS_LOG(DEBUG) << "Set replaced: " << replace_node->DebugString() << ", to replace: " << c_node->DebugString(); + } else { + auto inp_i_c_node = inp_i->cast(); + AnfNodePtr new_node = GetReplicatedNode(inp_i_c_node); + auto collected = CollectCNodeWithIsolateNodes(inp_i_c_node, inp_i_eval_result, new_node->func_graph()); + replace_node = collected; + } + collected_nodes.push_back(replace_node); + } + } + } + // Build depend node; + if (collected_nodes.empty()) { + MS_LOG_EXCEPTION << "Cannot find where the IsolateNodes come from, node: " << c_node->DebugString() + << ", abstract: " << c_node_eval_result->abstract()->ToString() + << ", flag: " << c_node_eval_result->HasIsolateNodesPropagateCNodeFlag(); + } + if (collected_nodes.size() == 1) { + auto new_cnode = collected_nodes[0]; + MS_LOG(DEBUG) << "Build possible depend node for node: " << c_node->DebugString() + << ", depend node: " << new_cnode->DebugString(); + return new_cnode; + } + AbstractBasePtrList tuple_abstract; + std::transform(collected_nodes.cbegin(), collected_nodes.cend(), std::back_inserter(tuple_abstract), + [](const auto &collected_node) { return collected_node->abstract(); }); + auto make_tuple_ops = BuildValueNode(prim::kPrimMakeTuple, FromValueInside(prim::kPrimMakeTuple)); + collected_nodes.insert(collected_nodes.begin(), make_tuple_ops); + AnfNodePtr new_cnode = new_fg->NewCNode(collected_nodes); + new_cnode->set_abstract(std::make_shared(tuple_abstract)); + MS_LOG(DEBUG) << "Build possible depend node for node: " << c_node->DebugString() + << ", depend node: " << new_cnode->DebugString(2); + + return new_cnode; +} + +void FuncGraphSpecializer::ProcessIsolateNodes() { + // Process isolate nodes, taking the isolate cnode as one unit, because it may be forwarded to a new cnode. + for (const auto &node : func_graph_->isolate_nodes()) { + ScopeGuard scope_guard(node->scope()); + auto conf = MakeConfig(node); + // First of node_pair is the original node or the forwarded node, second is the replaced node. + const auto &node_pair = BuildReplacedNode(conf); + auto &replace_node = node_pair.first; + MS_EXCEPTION_IF_NULL(replace_node); + replace_node->set_abstract(GetEvaluatedValueWrap(conf)); + MS_LOG(DEBUG) << "BuildReplacedNode for isolate node, new_node: " << replace_node->DebugString() + << ", old node: " << node->DebugString(); + // Only when the isolate node was forwarded, mark the node as processed. Otherwise the node is pushed to todo_ in + // BuildReplacedNode and will be processed as a normal node.
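// Editor's note: BuildReplacedNode now returns a pair instead of a single node, and callers such
// as the check just below destructure it to learn both whether forwarding happened and what to
// splice in. A usage sketch (C++17 structured bindings; `specializer`, `conf` and `node` are
// stand-ins for the surrounding context):
//
//   const auto &[forwarded, replaced] = specializer.BuildReplacedNode(conf);
//   if (node != forwarded) {
//     // the node was forwarded during analysis; mark it processed instead of re-visiting it
//   }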
+ if (node != node_pair.first) { + (void)marked_.insert(node); + } + } +} + +std::pair FuncGraphSpecializer::BuildReplacedNode(const AnfNodeConfigPtr &conf) { MS_EXCEPTION_IF_NULL(conf); auto conf_iter = engine_->anfnode_config_map().find(conf); AnfNodeConfigPtr new_conf = conf; while (conf_iter != engine_->anfnode_config_map().end()) { - MS_LOG(DEBUG) << "Origin conf: graph(" << new_conf->node()->func_graph()->ToString() << ", node(" - << new_conf->node()->DebugString() << ")"; + MS_LOG(DEBUG) << "Origin conf: , node(" << new_conf->node()->DebugString() << ")"; new_conf = conf_iter->second; MS_EXCEPTION_IF_NULL(new_conf); - MS_LOG(DEBUG) << "Replaced conf: graph(" << conf->node()->func_graph()->ToString() << ", node(" - << conf->node()->DebugString() << ")"; - (void)ReplicateDisconnectedNode(new_conf->node()); + const auto &forward_node = new_conf->node(); + MS_LOG(DEBUG) << "Replaced conf: , node(" << forward_node->DebugString() << ")"; + const auto &replicated_forward_node = ReplicateDisconnectedNode(forward_node); + if (replicated_forward_node && replicated_forward_node->isa()) { + // The AnfNode in order_list can be: + // case 1: also in FuncGraphManager, so it can be got from the nodes API of func_graph. It will + // be replaced in CloneOrderList in Cloner. + // case 2: an AnfNode that is not in FuncGraphManager because it was generated in the Analyze phase, + // so it will not be cloned by the normal clone API. + // 2.1: A forward node; the original node is in FuncGraphManager. The original node will + // be cloned in CloneOrderList in Cloner, and the replicated forward node will replace + // the replicated original node here. + // 2.2: An input of a forward node, such as the Cast CNode generated in DoCast. It is also another + // original node to forward. + // 2.3: An input of an input of a forward node, but not itself an original node, like the Cast CNode + // in MixedPrecisionCastHelper. + // For 2.2 and 2.3, we will put a placeholder in the order list of the replicated func_graph (refer to + // CloneOrderList), and it will be replaced inside ReplicateDisconnectedNode. + // For 2.1, the following code does the job: replace the replicated origin cnode with the replicated + // forward one in the replicated func_graph. + const auto &origin_node = conf_iter->first->node(); + const auto &replicated_origin_node = GetReplicatedNode(origin_node); + if (replicated_origin_node != origin_node) { + MS_LOG(DEBUG) << "Replace replicated origin node in order list: " << replicated_origin_node->DebugString() + << ", with replicated forwarded node: " << replicated_forward_node->DebugString(); + replicated_forward_node->func_graph()->ReplaceInOrder(replicated_origin_node, replicated_forward_node); + } else { + MS_LOG(EXCEPTION) << "Origin node is not replicated in specialized func_graph, origin node: " + << origin_node->DebugString(); + } + } conf_iter = engine_->anfnode_config_map().find(new_conf); } todo_.push_back(new_conf->node()); @@ -294,7 +459,7 @@ AnfNodePtr FuncGraphSpecializer::BuildReplacedNode(const AnfNodeConfigPtr &conf) MS_LOG(DEBUG) << "Set repl: graph(nullptr), node(" << repl->DebugString() << ") to replace origin: " << new_conf->node()->DebugString(); } - return repl; + return std::make_pair(new_conf->node(), repl); } namespace { @@ -333,6 +498,13 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNode(const AnfNodePtr &node, co << ", abstract: " << abs->ToString(); } } + // Set the flag, so this MetaFuncGraph will be Re-AutoMonaded.
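// Editor's note: auto-monad appends monad parameters to side-effecting graphs, so a MetaFuncGraph
// specialized against a trailing AbstractMonad argument must have auto-monad re-run on it after
// specialization. A minimal sketch of the trigger condition used just below (hypothetical helper):
static bool NeedReAutoMonad(const AbstractBasePtrList &argvals) {
  // true when the call carries a monad as its last, auto-monad-appended argument
  return argvals.size() > 1 && argvals.back() != nullptr && argvals.back()->isa<AbstractMonad>();
}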
+ if (func->isa()) { + auto specialized_fg = GetValueNode<FuncGraphPtr>(repl); + if (specialized_fg != nullptr && (argvals.size() > 1) && argvals[argvals.size() - 1]->isa()) { + specialized_fg->set_flag(mindspore::kFuncGraphFlagReAutoMonad, true); + } + } return repl; } @@ -425,7 +597,7 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedParameterNode(const CNodePtr &n MS_LOG(EXCEPTION) << "Size of cnode: " << cnode->DebugString() << " is not equal to 2 added to size of args: " << mindspore::ToString(partial_closure->args()); } - auto attrs = std::make_shared(); + static auto attrs = std::make_shared(); for (size_t i = 0; i < partial_closure->args().size(); i++) { auto old_node = cnode->input(i + 2); auto possibile_value_node = BuildPossibleValueNode(old_node, partial_closure->args()[i], attrs); @@ -504,7 +676,7 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) { std::vector args(new_inputs.begin() + 1, new_inputs.end()); // CNode(CNode(Partial, f, arg1), arg2, ...) --> CNode(f, arg1, arg2, ...) while (IsPrimitiveCNode(func, prim::kPrimPartial)) { - std::vector inputs = func->cast()->inputs(); + auto &inputs = func->cast()->inputs(); // First element is partial, second is func so arg is start from 2 (void)args.insert(args.begin(), inputs.begin() + 2, inputs.end()); func = inputs[1]; @@ -659,6 +831,22 @@ static PrimitivePtr BuildPrimtiveValueWithAttributes(const PrimitivePtr &prim, c return prim; } +// Return true if this node can be replaced by value. +static bool CanReplaceByValue(const AnfNodePtr &node) { + auto cnode = dyn_cast(node); + if (cnode == nullptr || cnode->inputs().empty()) { + return true; + } + auto &input0 = cnode->inputs().at(0); + // Keep parameters from being replaced by value. + if (input0->isa()) { + return false; + } + // Keep 'depend' nodes from being replaced by value. + auto prim = GetValueNode(input0); + return !IsPrimitiveEquals(prim, prim::kPrimDepend); +} + AnfNodePtr FuncGraphSpecializer::BuildPossibleValueNode(const AnfNodePtr &origin_node, const AbstractBasePtr &ival, const AttrValueMapPtr &attrs) { MS_EXCEPTION_IF_NULL(origin_node); @@ -699,8 +887,7 @@ AnfNodePtr FuncGraphSpecializer::BuildPossibleValueNode(const AnfNodePtr &origin if (val->isa()) { return nullptr; } - // keep primitive 'depend' not to be optimized - if (IsPrimitiveCNode(origin_node, prim::kPrimDepend)) { + if (!CanReplaceByValue(origin_node)) { return nullptr; } return BuildValueNode(val, ival); diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.h b/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.h index 2c08ea00ef..24892e117c 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.h +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.h @@ -98,6 +98,8 @@ class FuncGraphSpecializer : public std::enable_shared_from_this - // Build a replaceable node for iconf->node; it may be a replicated forwared CNode in static analysis or just a - // replicated node. - AnfNodePtr BuildReplacedNode(const AnfNodeConfigPtr &conf); + // Build a replaceable node for iconf->node; it may be a replicated forward CNode in static analysis or just a + // replicated node. The first of the returned pair is the origin node or the forward cnode, the second is the replaced node. + std::pair BuildReplacedNode(const AnfNodeConfigPtr &conf); + // Collect CNodes which have IsolateNodes that will be replaced by a ValueNode.
+ AnfNodePtr CollectCNodeWithIsolateNodes(const CNodePtr &c_node, const EvalResultPtr &c_node_eval_result, + const FuncGraphPtr &new_fg); // Build a specialized node from given argvals; AnfNodePtr BuildSpecializedNode(const AnfNodePtr &node, const AbstractBasePtr &abs, const AbstractBasePtrList &argvals); diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc index 119b6e19a8..f8c8ae2f3f 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.cc @@ -91,12 +91,6 @@ std::size_t AnfNodeConfigHasher::operator()(const AnfNodeConfigPtr conf) const { if (!conf->context()->IsDummyContext()) { hash_value = hash_combine(hash_value, std::hash{}(conf->context().get())); } - if (conf->context() != nullptr && conf->context()->func_graph() != nullptr) { - MS_LOG(DEBUG) << "NodeConfigHasher Node: " << conf->node()->DebugString() - << ", Graph: " << conf->context()->func_graph()->ToString() << " ### , hash value: " << hash_value; - } else { - MS_LOG(DEBUG) << "NodeConfigHasher Node: " << conf->node()->DebugString() << " ### , hash value: " << hash_value; - } return hash_value; } @@ -147,7 +141,7 @@ EvalResultPtr AnalysisEngine::GetEvaluatedValue(const AnfNodeConfigPtr &conf) { auto value = cache_.GetValue(conf); if (value != nullptr) { MS_LOG(DEBUG) << "Evaluate cache hit for NodeConfig: " << conf->ToString() << ", Value: " << value->abstract().get() - << ", " << value->abstract()->ToString(); + << ", " << value->abstract()->ToString() << ", flag: " << value->HasIsolateNodesPropagateCNodeFlag(); return value; } @@ -156,6 +150,9 @@ EvalResultPtr AnalysisEngine::GetEvaluatedValue(const AnfNodeConfigPtr &conf) { if (value == nullptr) { MS_LOG(EXCEPTION) << "Evaluate for NodeConfig " << conf->ToString() << " get nullptr"; } + MS_LOG(DEBUG) << "Evaluate node on demand for NodeConfig: " << conf->ToString() + << ", Value: " << value->abstract().get() << ", " << value->abstract()->ToString() + << ", flag: " << value->HasIsolateNodesPropagateCNodeFlag(); cache_.set_value(conf, value); return value; } @@ -180,16 +177,17 @@ EvalResultPtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) { eval_result = std::make_shared(node->abstract(), std::make_shared()); } else if (node->isa()) { auto value_node = node->cast(); - eval_result = std::make_shared(EvalValueNode(value_node, conf), nullptr); + auto abstract = EvalValueNode(value_node, conf); + eval_result = std::make_shared(abstract, std::make_shared()); } else if (node->isa()) { - CheckNoStackInSameFuncGraph(conf); + // CheckNoStackInSameFuncGraph(conf); auto cnode = node->cast(); trace::TraceEvalCNodeEnter(conf); eval_result = EvalCNode(cnode, conf); trace::TraceEvalCNodeLeave(); } else { - MS_LOG(EXCEPTION) << "Illegal AnfNode for evaluating, " << node->DebugString() - << (node->func_graph() != nullptr ? node->func_graph()->ToString() : "nullgraph") + MS_LOG(EXCEPTION) << "Illegal AnfNode for evaluating, node: " << node->DebugString() + << ", fg: " << (node->func_graph() != nullptr ? node->func_graph()->ToString() : "nullgraph") << ".
NodeInfo: " << trace::GetDebugInfo(node->debug_info()); } @@ -200,7 +198,8 @@ EvalResultPtr AnalysisEngine::Eval(const AnfNodeConfigPtr &conf) { << " NodeInfo: " << trace::GetDebugInfo(node->debug_info()); } #endif - MS_LOG(DEBUG) << "End Eval NodeConfig " << conf->ToString() << ", res: " << eval_result->abstract()->ToString(); + MS_LOG(DEBUG) << "End Eval NodeConfig " << conf->ToString() << ", res: " << eval_result->abstract()->ToString() + << ", flag: " << eval_result->HasIsolateNodesPropagateCNodeFlag(); return eval_result; } @@ -252,6 +251,20 @@ AbstractBasePtr AnalysisEngine::EvalValueNode(const ValueNodePtr &value_node, co return out; } +static bool CheckIsolateNodesPropagateFlag(const AbstractFunctionPtr &abs_func, const ConfigPtrList &conf_list) { + if (abs_func->HasIsolateNodesFlag()) { + MS_LOG(DEBUG) << "Propagate isolate nodes flag from: " << abs_func->ToString(); + return true; + } + auto flag = std::any_of(conf_list.cbegin(), conf_list.cend(), [](const ConfigPtr &conf) { + auto eval_result = conf->GetEvaluatedValue(); + MS_LOG(DEBUG) << "Propagate isolate nodes flag from: " << eval_result->abstract()->ToString() + << ", flag: " << eval_result->HasIsolateNodesPropagateCNodeFlag(); + return eval_result->HasIsolateNodesPropagateCNodeFlag(); + }); + return flag; +} + EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConfigPtr &conf) { MS_EXCEPTION_IF_NULL(conf); MS_EXCEPTION_IF_NULL(cnode); @@ -267,7 +280,8 @@ EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConf AnfNodeConfigPtr func_conf = MakeConfig(func_node, context); MS_EXCEPTION_IF_NULL(func_conf); // Keep it in a local variable, otherwise smart pointer will free it. - AbstractBasePtr maybe_func = func_conf->GetEvaluatedValue()->abstract(); + auto maybe_func_eval_result = func_conf->GetEvaluatedValue(); + AbstractBasePtr maybe_func = maybe_func_eval_result->abstract(); if (maybe_func == nullptr) { MS_LOG(EXCEPTION) << "func_conf.GetEvaluatedValue() return null, func_conf: " << func_conf->ToString() << " NodeInfo: " << trace::GetDebugInfo(cnode->debug_info()); @@ -298,7 +312,23 @@ EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConf }; func->Visit(build_evaluator); - return ExecuteEvaluators(infs, conf, args_conf_list); + auto eval_result = ExecuteEvaluators(infs, conf, args_conf_list); + auto flag = CheckIsolateNodesPropagateFlag(func, args_conf_list); + if (flag != eval_result->HasIsolateNodesPropagateCNodeFlag()) { + MS_LOG(DEBUG) << "Different propagate isolate nodes flag from: " << eval_result->abstract()->ToString() + << ", cnode flag: " << eval_result->HasIsolateNodesPropagateCNodeFlag() + << ", funcgraph flag: " << eval_result->HasIsolateNodesPropagateFuncGraphFlag() + << ", check flag:" << flag; + // This eval_result may be fetch from an Evaluator's cache based on args_spec_list equality. + // But args may be come from different CNode, so propagate flag is not same, + // a new copy of eval_result should be used. + auto new_eval_result = eval_result->Clone(); + // FuncGraph flag should be used for HOF call or used FuncGraph propagate. 
+ flag = flag | new_eval_result->HasIsolateNodesPropagateFuncGraphFlag(); + new_eval_result->SetIsolateNodesPropagateCNodeFlag(flag); + eval_result = new_eval_result; + } + return eval_result; } EvalResultPtr AnalysisEngine::Execute(const AbstractFunctionPtr &func, const AbstractBasePtrList &args_spec_list) { @@ -534,6 +564,33 @@ EvaluatorPtr AnalysisEngine::GetEvaluatorFor(const AbstractFunctionPtr &func) { return tracked_eval; } +EvalResultPtr AnalysisEngine::ForwardConfig(const AnfNodeConfigPtr &orig_conf, const AnfNodeConfigPtr new_conf) { + // Using anfnode_config_map_[orig_conf] = new_conf would require AnfNodeConfig to provide a copy constructor. + (void)anfnode_config_map_.emplace(orig_conf, new_conf); + MS_LOG(DEBUG) << "Forward orig_conf: " << orig_conf->node()->DebugString() + << ", to new_conf: " << new_conf->node()->DebugString(); + if (orig_conf->node()->isa()) { + auto old_cnode = orig_conf->node()->cast(); + MS_EXCEPTION_IF_NULL(old_cnode); + if (new_conf->node()->isa()) { + auto new_cnode = new_conf->node()->cast(); + MS_EXCEPTION_IF_NULL(new_cnode); + if (old_cnode->func_graph() == new_cnode->func_graph()) { + MS_LOG(DEBUG) << "Try to remove forward node from order list, forward node: " << new_cnode->ToString() + << ", as origin node should be in order list, origin_node: " << old_cnode->ToString(); + old_cnode->func_graph()->EraseUnusedNodeInOrder(new_cnode); + } else { + MS_LOG(EXCEPTION) << "Forward orig_node to different func_graph, old_node: " << old_cnode->DebugString() + << ", new_node: " << new_cnode->DebugString(); + } + } + } + forward_count_++; + auto res = GetEvaluatedValue(new_conf); + forward_count_--; + return res; +} + EvalResultPtr AnalysisEngine::ExecuteEvaluators(const std::vector &evaluators, const AnfNodeConfigPtr &out_conf, const ConfigPtrList &args_conf_list) { if (evaluators.size() == 1) { @@ -569,7 +626,7 @@ EvaluatorPtr AnalysisEngine::HandleNestedRecursion(const std::vector - auto it_temp = std::find(evaluators.begin(), evaluators.end(), r_it->first); + auto it_temp = std::find(evaluators.begin(), evaluators.end(), r_it->evaluator_); if (it_temp != evaluators.end()) { latest_entry = *it_temp; latest_entry_iter = r_it; @@ -585,20 +642,21 @@ EvaluatorPtr AnalysisEngine::HandleNestedRecursion(const std::vector - std::set> undetermined_evals; + std::unordered_set undetermined_evals; for (auto r_it = eval_trace_.rbegin(); r_it != latest_entry_iter; r_it++) { undetermined_evals.insert(*r_it); } MS_LOG(DEBUG) << "undetermined_evals size(): " << undetermined_evals.size(); for (auto u_eval : undetermined_evals) { - MS_LOG(DEBUG) << u_eval.first->ToString() << " check undetermined."; + MS_LOG(DEBUG) << u_eval.evaluator_->ToString() << " check undetermined."; - auto &alternate_evaluator = multi_poss_[u_eval.first]; + auto &alternate_evaluator = multi_poss_[u_eval.evaluator_]; auto &eval_cache = alternate_evaluator->cache(); - if ((!undetermined_evals.count(std::make_pair(alternate_evaluator, args_spec_list))) && + const auto &alt_eval_args = EvaluatorArgs(alternate_evaluator, args_spec_list); + if ((!undetermined_evals.count(alt_eval_args)) && (((!continued_evals_.count(u_eval)) && (eval_cache->find(args_spec_list) != eval_cache->end())) || (eval_cache->find(args_spec_list) == eval_cache->end()))) { - MS_LOG(DEBUG) << u_eval.first->ToString() << " has undetermined."; + MS_LOG(DEBUG) << u_eval.evaluator_->ToString() << " has undetermined."; has_undetermined = true; break; } @@ -645,7 +703,7 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vectorToString(); // If current evaluator is under tracing, then skip current evaluator to
avoid recursively evaluating. auto it = std::find(eval_trace_.rbegin(), eval_trace_.rend(), current_inf); @@ -663,13 +721,13 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector - MS_LOG(DEBUG) << "continued_evals_ add " << current_inf.first.get() << current_inf.first->ToString(); + MS_LOG(DEBUG) << "continued_evals_ add " << current_inf.evaluator_.get() << current_inf.evaluator_->ToString(); continued_evals_.insert(current_inf); continue; } // Try to travel the latest undetermined. - if (latest_entry != eval_trace_.rbegin()->first) { + if (latest_entry != eval_trace_.rbegin()->evaluator_) { MS_LOG(DEBUG) << "Direct Run Evaluator " << eval.get() << "----" << eval->ToString(); auto eval_result = latest_entry->Run(shared_from_this(), args_conf_list, out_conf); MS_EXCEPTION_IF_NULL(eval_result->abstract()); diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h index 31d65c8e66..b7b004f99e 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/static_analysis.h @@ -27,6 +27,7 @@ #include #include #include +#include <unordered_set> #ifdef DEBUG #include @@ -45,6 +46,9 @@ namespace abstract { using AttrValueMap = std::unordered_map; using AttrValueMapPtr = std::shared_ptr; +inline const int kIsolateNodesPropagateCNodeFlag = 1; +inline const int kIsolateNodesPropagateFuncGraphFlag = 2; + // the class to save evaluated result: abstract value and modified attribute class EvalResult : public Base { public: @@ -54,12 +58,46 @@ class EvalResult : public Base { AbstractBasePtr abstract() { return abstract_; } AttrValueMapPtr attribute() { return attribute_; } + std::shared_ptr<EvalResult> Clone() const { + auto cloned = std::make_shared<EvalResult>(abstract_, attribute_); + cloned->SetIsolateNodesPropagateCNodeFlag(HasIsolateNodesPropagateCNodeFlag()); + cloned->SetIsolateNodesPropagateFuncGraphFlag(HasIsolateNodesPropagateFuncGraphFlag()); + return cloned; + } + // The related AbstractBase is evaluated from a CNode whose input has isolate nodes. + // This flag is propagated to all user nodes. + // When a node A can be specialized to a ValueNode, we should check whether node A has this flag; + // if it does, the original FuncGraph call should be depended on, so its side effects will not + // be lost. + bool HasIsolateNodesPropagateCNodeFlag() const { + auto iter = eval_attr_.find(kIsolateNodesPropagateCNodeFlag); + if (iter != eval_attr_.end()) { + return GetValue<bool>(iter->second); + } + return false; + } + void SetIsolateNodesPropagateCNodeFlag(bool flag) { eval_attr_[kIsolateNodesPropagateCNodeFlag] = MakeValue(flag); } + + // The FuncGraph itself may not have IsolateNodes, but a used FuncGraph or an HOF call may have IsolateNodes. + bool HasIsolateNodesPropagateFuncGraphFlag() const { + auto iter = eval_attr_.find(kIsolateNodesPropagateFuncGraphFlag); + if (iter != eval_attr_.end()) { + return GetValue<bool>(iter->second); + } + return false; + } + void SetIsolateNodesPropagateFuncGraphFlag(bool flag) { + eval_attr_[kIsolateNodesPropagateFuncGraphFlag] = MakeValue(flag); + } + private: AbstractBasePtr abstract_; + // Attribute related to PrimEvaluator. AttrValueMapPtr attribute_; + std::unordered_map<int, ValuePtr> eval_attr_; }; - using EvalResultPtr = std::shared_ptr<EvalResult>; + // Superclass for AnfNodeConfig and VirtualConfig.
class Config : public Base { public: @@ -174,7 +212,6 @@ struct AnalysisResult { AnalysisContextPtr context; }; -using EvalTraceRevIter = std::list>::reverse_iterator; struct PartialAppHasher { std::size_t operator()(const std::pair &p) const { auto h1 = std::hash{}(p.first); @@ -222,16 +259,7 @@ class AnalysisEngine : public std::enable_shared_from_this { // Set the analysis result for orig to the result for new. // This sets an entry in anfnode_config_map from orig to new. - EvalResultPtr ForwardConfig(const AnfNodeConfigPtr &orig_conf, const AnfNodeConfigPtr new_conf) { - // Use anfnode_config_map_[orig_conf] = new_conf will require AnfNodeConfig provide copy constructor. - (void)anfnode_config_map_.emplace(orig_conf, new_conf); - MS_LOG(DEBUG) << "Forward orig_conf: " << orig_conf->node()->DebugString() - << ", to new_conf: " << new_conf->node()->DebugString(); - forward_count_++; - auto res = GetEvaluatedValue(new_conf); - forward_count_--; - return res; - } + EvalResultPtr ForwardConfig(const AnfNodeConfigPtr &orig_conf, const AnfNodeConfigPtr new_conf); const PrimEvaluatorMap &PrimConstructors() const { return prim_constructors_; } AnalysisCache cache_; @@ -253,6 +281,33 @@ class AnalysisEngine : public std::enable_shared_from_this { void CheckNoStackInSameFuncGraph(const AnfNodeConfigPtr &conf); private: + // Compare Args based on value rather than pointer. + struct EvaluatorArgs { + EvaluatorArgs(const EvaluatorPtr &eval, const AbstractBasePtrList &args) : evaluator_(eval), args_(args) {} + bool operator==(const EvaluatorArgs &other) const { + if (evaluator_ != other.evaluator_) { + return false; + } + if (AbstractBasePtrListDeepEqual(args_, other.args_)) { + return true; + } + return false; + } + bool operator!=(const EvaluatorArgs &other) { return !(*this == other); } + + EvaluatorPtr evaluator_; + AbstractBasePtrList args_; + }; + using EvalTraceRevIter = std::list<EvaluatorArgs>::reverse_iterator; + struct EvaluatorArgsHasher { + std::size_t operator()(const EvaluatorArgs &eval_args) const { + return hash_combine(std::hash{}(eval_args.evaluator_), AbstractBasePtrListHash(eval_args.args_)); + } + }; + struct EvaluatorArgsEqual { + bool operator()(const EvaluatorArgs &lhs, const EvaluatorArgs &rhs) const { return lhs == rhs; } + }; + void SetUndeterminedFlag(const EvaluatorPtr &evaluator); EvaluatorPtr HandleNestedRecursion(const std::vector &evaluators, const EvaluatorPtr &eval, const AbstractBasePtrList &args_spec_list, const EvalTraceRevIter &it, @@ -266,9 +321,9 @@ class AnalysisEngine : public std::enable_shared_from_this { constructors_app_; AnfNodeConfigMap anfnode_config_map_; // Use a list to trace multiple evaluators.
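// Editor's note: replacing the raw std::pair entries with EvaluatorArgs gives the trace
// structures value semantics: two entries with deep-equal abstract argument lists hash and
// compare alike even when the AbstractBasePtr pointers differ. A small usage sketch:
//
//   std::unordered_set<EvaluatorArgs, EvaluatorArgsHasher, EvaluatorArgsEqual> seen;
//   (void)seen.insert(EvaluatorArgs(evaluator, args_spec_list));
//   bool duplicate = seen.count(EvaluatorArgs(evaluator, deep_equal_copy)) > 0;  // true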
- std::list> eval_trace_; + std::list<EvaluatorArgs> eval_trace_; std::map multi_poss_; - std::set> continued_evals_; + std::unordered_set<EvaluatorArgs, EvaluatorArgsHasher, EvaluatorArgsEqual> continued_evals_; AnalysisContextPtr Run(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context, const ConfigPtrList &args_conf_list); diff --git a/mindspore/ccsrc/pipeline/jit/validator.cc b/mindspore/ccsrc/pipeline/jit/validator.cc index ad2b127fdd..9251798fd1 100644 --- a/mindspore/ccsrc/pipeline/jit/validator.cc +++ b/mindspore/ccsrc/pipeline/jit/validator.cc @@ -109,6 +109,11 @@ void ValidateAbstract(const AnfNodePtr &node) { return; } + // UMonad or IOMonad + if (ptrBase->isa()) { + return; + } + // Other types show exception MS_LOG(EXCEPTION) << "Illegal type in the graph: " << ptrBase->ToString(); } diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index c4fdb0c92d..3df381eb1a 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -42,6 +42,7 @@ #include "pipeline/jit/parse/data_converter.h" #include "pipeline/jit/parse/resolve.h" #include "pipeline/jit/static_analysis/prim.h" +#include "pipeline/jit/static_analysis/auto_monad.h" #include "backend/session/session_factory.h" #include "backend/optimizer/pass/const_input_to_attr_registry.h" #include "backend/optimizer/common/helper.h" @@ -691,7 +692,13 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v const auto &signature = prim->signatures(); auto sig_size = signature.size(); auto size = op_exec_info->op_inputs.size(); - // ignore signature for cast op + + // ignore monad signatures + for (auto sig : signature) { + if (sig.default_value != nullptr && sig.default_value->isa()) { + --sig_size; + } + } if (sig_size > 0 && sig_size != size) { MS_EXCEPTION(ValueError) << op_exec_info->op_name << " inputs size " << size << " does not match the requires " << "inputs size " << sig_size; @@ -757,7 +764,7 @@ AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::v CNodePtr cnode = nullptr; if (need_construct_graph()) { MS_EXCEPTION_IF_NULL(curr_g_); - cnode = curr_g_->NewCNode(inputs); + cnode = curr_g_->NewCNodeInOrder(inputs); MS_LOG(DEBUG) << "Make CNode for " << op_exec_info->op_name << " new cnode is " << cnode->DebugString(4); } return cnode; @@ -2364,9 +2371,11 @@ FuncGraphPtr PynativeExecutor::MakeGradGraph(const py::object &cell, const FuncG (void)bprop_graph->transforms().emplace(std::make_pair("primal", FuncGraphTransform(g))); } } - DumpGraphIR("fg.ir", g); auto is_top = IsTopGraph(cell_id); MS_LOG(DEBUG) << "Grad top cell " << is_top; + // Before making the grad graph, we need to run auto-monad on the forward graph, + // so that side effects in the forward graph can be handled in the grad graph.
+ (void)pipeline::AutoMonad(g); set_need_replace_forward(!IsNestedGrad()); // Obtain grad graph auto newfg = ad::Grad(g, r, is_top); diff --git a/mindspore/ccsrc/profiler/device/gpu/gpu_profiling_utils.cc b/mindspore/ccsrc/profiler/device/gpu/gpu_profiling_utils.cc index eec4d23dfa..efacdb7005 100644 --- a/mindspore/ccsrc/profiler/device/gpu/gpu_profiling_utils.cc +++ b/mindspore/ccsrc/profiler/device/gpu/gpu_profiling_utils.cc @@ -120,7 +120,8 @@ void ProfilingUtils::SetTraceBpEnd(const std::vector &cnode_exec_order if (iter != cnode_exec_order.rend()) { // store communication op input nodes' name std::set ar_input_node_names; - for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(*iter); ++i) { + size_t input_num = AnfAlgo::GetInputTensorNum(*iter); + for (size_t i = 0; i < input_num; ++i) { auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(*iter, i); auto input_node = input_node_with_index.first; ar_input_node_names.insert(input_node->fullname_with_scope()); diff --git a/mindspore/ccsrc/ps/parameter_server.h b/mindspore/ccsrc/ps/parameter_server.h index 2a790f7332..e28fcb4514 100644 --- a/mindspore/ccsrc/ps/parameter_server.h +++ b/mindspore/ccsrc/ps/parameter_server.h @@ -775,6 +775,10 @@ void ParameterServer::GetEmbeddingTableParamPtr() { std::string cnode_name = AnfAlgo::GetCNodeName(cnode); if (cnode_name == kEmbeddingLookupOpName || cnode_name == kGatherV2OpName || cnode_name == kSparseGatherV2OpName) { auto embedding_table = AnfAlgo::GetInputNode(cnode, 0); + if (IsPrimitiveCNode(embedding_table, prim::kPrimLoad)) { + auto embedding_cnode = embedding_table->cast(); + embedding_table = AnfAlgo::GetInputNode(embedding_cnode, 0); + } MS_EXCEPTION_IF_NULL(embedding_table); if (embedding_table->isa()) { MS_LOG(INFO) << "Embedding table name is " << embedding_table->fullname_with_scope() << ", key is " << count; diff --git a/mindspore/ccsrc/pybind_api/ir/value_py.cc b/mindspore/ccsrc/pybind_api/ir/value_py.cc index 1d80c74c4d..e0d49e7719 100644 --- a/mindspore/ccsrc/pybind_api/ir/value_py.cc +++ b/mindspore/ccsrc/pybind_api/ir/value_py.cc @@ -21,9 +21,11 @@ #include "abstract/abstract_value.h" namespace mindspore { -// Define python 'RefKey' class. +// Define python class for values. 
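// Editor's note: the registration below exposes the monad values to Python so monad singletons
// can cross the C++/Python boundary (e.g. as trailing operator arguments). A standalone pybind11
// equivalent, assuming UMonad and IOMonad are default-constructible (hypothetical module name):
//
//   #include <pybind11/pybind11.h>
//   namespace py = pybind11;
//   PYBIND11_MODULE(monads_sketch, m) {
//     (void)py::class_<UMonad, std::shared_ptr<UMonad>>(m, "UMonad").def(py::init<>());
//     (void)py::class_<IOMonad, std::shared_ptr<IOMonad>>(m, "IOMonad").def(py::init<>());
//   }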
REGISTER_PYBIND_DEFINE( - RefKey, ([](const py::module *m) { + Values, ([](const py::module *m) { (void)py::class_>(*m, "RefKey").def(py::init(), py::arg("tag")); + (void)py::class_>(*m, "UMonad").def(py::init()); + (void)py::class_>(*m, "IOMonad").def(py::init()); })); } // namespace mindspore diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc index 9c669fd510..c94f0f71dc 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc @@ -524,6 +524,9 @@ bool AscendDeviceAddress::SyncHostToDevice(const ShapeVector &shape, size_t size const void *host_ptr) const { MS_LOG(INFO) << "SyncHostToDevice, Device(format:" << format_ << ", type_id:" << TypeIdLabel(type_id_) << ", size:" << size_ << "), Host(type_id:" << TypeIdLabel(type) << ", size:" << size << ")"; + if (type_id_ > kMonadTypeBegin && type_id_ < kMonadTypeEnd) { + return true; + } SyncStream(); bool sync_ok = false; std::vector host_shape; diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc index 83edaa3621..13556fedf7 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc @@ -29,7 +29,6 @@ #include "runtime/device/ascend/profiling/profiling_manager.h" #include "common/trans.h" #include "runtime/context.h" -#include "runtime/device/ascend/ascend_label_assign.h" #include "runtime/device/ascend/ascend_stream_assign.h" #include "framework/ge_runtime/model_runner.h" #include "runtime/device/ascend/tasksink/task_generator.h" @@ -425,7 +424,6 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) { } AscendStreamAssign &assign_instance = AscendStreamAssign::GetInstance(); AscendResourceMng &resource_manager = AscendResourceMng::GetInstance(); - AscendLabelAssign &label_assign_instance = AscendLabelAssign::GetInstance(); // the streams' flag not HEAD_STREAM std::vector wait_active_stream_list; assign_instance.GetWaitStreams(&wait_active_stream_list); @@ -433,14 +431,13 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) { assign_instance.GetHcomStreams(&force_copy_stream_list); MS_LOG(INFO) << "Call DavinciModel total stream num:" << resource_manager.get_cur_stream_num() << ", total event num:" << resource_manager.get_cur_event_num() - << ", total label num:" << label_assign_instance.GetLabelNum(NOT_NULL(graph)) + << ", total label num:" << graph->label_num() << ", wait_active_stream_list size:" << wait_active_stream_list.size() << ", force_copy_stream_list size:" << force_copy_stream_list.size(); std::vector> empty_list; auto model = std::make_shared( task_info_list, empty_list, empty_list, empty_list, empty_list, wait_active_stream_list, force_copy_stream_list, 0, - 0, 0, 0, 0, 0, resource_manager.get_cur_stream_num(), label_assign_instance.GetLabelNum(NOT_NULL(graph)), - resource_manager.get_cur_event_num(), 0); + 0, 0, 0, 0, 0, resource_manager.get_cur_stream_num(), graph->label_num(), resource_manager.get_cur_event_num(), 0); auto ret = graph_model_map_.insert(std::make_pair(graph->graph_id(), model)); if (!ret.second) { MS_LOG(EXCEPTION) << "Duplicate GraphId! 
Please check in ascend_session."; diff --git a/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc b/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc index a07fa0cc04..28d68cd8e2 100644 --- a/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc +++ b/mindspore/ccsrc/runtime/device/ascend/kernel_select_ascend.cc @@ -80,7 +80,8 @@ string GetPriorityMatchFormat(const CNodePtr &cnode) { bool is_init = false; bool need_change_nd = false; bool is_5d_input = false; - for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(cnode); ++index) { + size_t input_num = AnfAlgo::GetInputTensorNum(cnode); + for (size_t index = 0; index < input_num; ++index) { auto pre_output_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, index); if (AnfAlgo::IsFeatureMapInput(cnode, index) && kHWSpecialFormatSet.find(pre_output_format) != kHWSpecialFormatSet.end()) { @@ -140,7 +141,8 @@ void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, cons MS_LOG(EXCEPTION) << "Out of range cur_kernel info_match_counts " << MATCH_COUNT_PRIORITY_END; } auto pri_match_format = GetPriorityMatchFormat(kernel_node); - for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + for (size_t input_index = 0; input_index < input_num; ++input_index) { auto input_anf_node = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(kernel_node, input_index), 0).first; MS_EXCEPTION_IF_NULL(input_anf_node); // we do not take ValueNode into consideration in graph kernel. @@ -168,7 +170,8 @@ void UpdateCurMatchCounts(const kernel::KernelBuildInfo &kernel_build_info, cons } } - for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) { + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + for (size_t output_index = 0; output_index < output_num; ++output_index) { // cal count of same output dtype between abstract and kernel info if (kernel_build_info.GetOutputDeviceType(output_index) == AnfAlgo::GetOutputInferDataType(kernel_node, output_index)) { @@ -327,14 +330,14 @@ void SetCastAndWeightFormat(const CNodePtr &kernel_node) { if (!AnfAlgo::HasNodeAttr(kAttrPynativeNextIndex, kernel_node) || !AnfAlgo::HasNodeAttr(kAttrPynativeNextOpName, kernel_node)) { MS_LOG(EXCEPTION) << "The node [" << kernel_node->DebugString() << "] attr of " << kAttrPynativeNextIndex << " or " - << kAttrPynativeNextOpName << " has been not setted yet!" + << kAttrPynativeNextOpName << " has not been set yet!" 
<< " trace: " << trace::DumpSourceLines(kernel_node); } auto next_index = AnfAlgo::GetNodeAttr(kernel_node, kAttrPynativeNextIndex); auto next_op_name = AnfAlgo::GetNodeAttr(kernel_node, kAttrPynativeNextOpName); auto iter = kNextOpFormatList.find(next_op_name); if (iter == kNextOpFormatList.end()) { - MS_LOG(INFO) << "The op name " << next_op_name << "has been not setted in the next op map "; + MS_LOG(INFO) << "The op name " << next_op_name << "has not been set in the next op map "; return; } if (iter->second.size() < next_index) { @@ -405,7 +408,8 @@ void SetTensorDeviceInfo(const CNodePtr &kernel_node) { MS_EXCEPTION_IF_NULL(kernel_node); auto selected_kernel_info = AnfAlgo::GetSelectKernelBuildInfo(kernel_node); MS_EXCEPTION_IF_NULL(selected_kernel_info); - for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + for (size_t input_index = 0; input_index < input_num; ++input_index) { auto input_kernel_node = AnfAlgo::GetInputNode(kernel_node, input_index); MS_EXCEPTION_IF_NULL(input_kernel_node); auto input_with_index = AnfAlgo::VisitKernel(input_kernel_node, 0); @@ -438,7 +442,7 @@ KernelSelectStatus SetMatchedKernelInfo(const CNodePtr &kernel_node, bool precision_reduce = false; std::shared_ptr selected_kernel_info = nullptr; // Matched kernel info - // Filter kernel info matched with me infered type + // Filter kernel info matched with me inferred type auto filtered_kernel_info_list = FilteredKernelInfoByDtype(kernel_node, kernel_info_list); if (!filtered_kernel_info_list.empty()) { selected_kernel_info = ChooseMatchedKernelInfo(kernel_node, filtered_kernel_info_list); diff --git a/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_utils.cc b/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_utils.cc index 4329a43e33..82b0115ec4 100644 --- a/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_utils.cc +++ b/mindspore/ccsrc/runtime/device/ascend/profiling/profiling_utils.cc @@ -161,7 +161,8 @@ std::string ProfilingUtils::GetTraceBpEnd(const std::vector &cnode_exe if (AnfAlgo::IsCommunicationOp(*iter)) { // store communication op input nodes' name std::set ar_input_node_names; - for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(*iter); ++i) { + size_t input_num = AnfAlgo::GetInputTensorNum(*iter); + for (size_t i = 0; i < input_num; ++i) { auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(*iter, i); auto input_node = input_node_with_index.first; ar_input_node_names.insert(input_node->fullname_with_scope()); diff --git a/mindspore/ccsrc/runtime/device/ascend/tasksink/task_generator.cc b/mindspore/ccsrc/runtime/device/ascend/tasksink/task_generator.cc index 922c4b10d6..86f7454eb2 100644 --- a/mindspore/ccsrc/runtime/device/ascend/tasksink/task_generator.cc +++ b/mindspore/ccsrc/runtime/device/ascend/tasksink/task_generator.cc @@ -142,7 +142,8 @@ bool TaskGenerator::LaunchKernel(const CNodePtr &anf_node_ptr, uint32_t stream_i } if (op_name != kAtomicAddrCleanOpName) { - for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(anf_node_ptr); ++i) { + size_t input_num = AnfAlgo::GetInputTensorNum(anf_node_ptr); + for (size_t i = 0; i < input_num; ++i) { if (op_name == kDynamicRNNOpName && i == 3) { continue; } @@ -177,12 +178,16 @@ bool TaskGenerator::LaunchKernel(const CNodePtr &anf_node_ptr, uint32_t stream_i kernel_inputs.push_back(input); } - for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(anf_node_ptr); ++i) { - auto it = 
AnfAlgo::GetOutputAddr(anf_node_ptr, i); - AddressPtr output = std::make_shared<Address>
(); - output->addr = it->ptr_; - output->size = it->size_; - kernel_outputs.push_back(output); + // No kernel output if the output of the cnode is a monad, such as LabelSwitch. + if (!HasAbstractMonad(anf_node_ptr)) { + size_t output_num = AnfAlgo::GetOutputTensorNum(anf_node_ptr); + for (size_t i = 0; i < output_num; ++i) { + auto it = AnfAlgo::GetOutputAddr(anf_node_ptr, i); + AddressPtr output = std::make_shared<Address>
(); + output->addr = it->ptr_; + output->size = it->size_; + kernel_outputs.push_back(output); + } } for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) { diff --git a/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc index c6815a343e..20f291a32f 100644 --- a/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc @@ -286,7 +286,7 @@ void CPUKernelRuntime::BindInputTensorAddressPtr(const session::KernelGraph &ker size_t input_idx = 0; for (auto &item : input_nodes) { MS_EXCEPTION_IF_NULL(item); - if (item->isa()) { + if (item->isa() && !HasAbstractMonad(item)) { auto address = AnfAlgo::GetMutableOutputAddr(item, 0); auto tensor = inputs[input_idx]; auto tensor_address = tensor->device_address(); diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc index dbdd17c915..8579dcbbe0 100644 --- a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc @@ -921,7 +921,8 @@ bool GPUKernelRuntime::AllocKernelInputDynamicRes(const mindspore::AnfNodePtr &k MS_EXCEPTION_IF_NULL(kernel); MS_EXCEPTION_IF_NULL(kernel_inputs); MS_EXCEPTION_IF_NULL(mem_reuse_util_); - for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel); + for (size_t i = 0; i < input_num; ++i) { DeviceAddressPtr device_address; if (mem_reuse_util_->is_all_nop_node()) { // Graph may be all nop nodes and not remove nop node, so this can not skip nop node. @@ -1099,7 +1100,8 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel) return; } // Free the input of kernel by reference count. - for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel); + for (size_t i = 0; i < input_num; ++i) { if (AnfAlgo::IsInplaceNode(kernel, "aggregate")) { auto primitive = AnfAlgo::GetCNodePrimitive(kernel); auto index = GetValue(primitive->GetAttr("aggregate_input_index")); @@ -1137,7 +1139,8 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel) } } // Free the output of kernel, if output has no reference. - for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(kernel); ++i) { + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel); + for (size_t i = 0; i < output_num; ++i) { auto kernel_ref_count_ptr = mem_reuse_util_->GetRef(cnode, i); if (kernel_ref_count_ptr == nullptr) { continue; diff --git a/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.cc b/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.cc index 21a08868bc..82cb38aa19 100644 --- a/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.cc +++ b/mindspore/ccsrc/runtime/device/gpu/kernel_info_setter.cc @@ -138,7 +138,8 @@ bool SelectAkgKernel(const CNodePtr &kernel_node, const std::shared_ptrinput(input_index + 1); MS_EXCEPTION_IF_NULL(input_kernel_node); auto input_with_index = AnfAlgo::VisitKernel(input_kernel_node, 0); @@ -200,7 +201,8 @@ bool IsNeedProcessFormatInfo(const CNodePtr &kernel_node, const std::vectorsecond.first; // If input position is empty, then insert all the input positions, because the input numbers of this op are variable. 
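// Editor's note: a recurring mechanical change in this patch hoists the
// AnfAlgo::GetInputTensorNum/GetOutputTensorNum bound out of the loop header, presumably because
// each call re-derives the count from the node rather than returning a cached size. The shape of
// the fix, as it appears throughout:
//
//   size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);  // evaluate once
//   for (size_t input_index = 0; input_index < input_num; ++input_index) {
//     // ... per-input work ...
//   }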
if (inputs_format_position.size() == 0) { - for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); input_index++) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + for (size_t input_index = 0; input_index < input_num; input_index++) { inputs_format_position.push_back(input_index); } } @@ -226,7 +228,8 @@ void UpdateKernelFormatInfo(const CNodePtr &kernel_node, const std::vectorsecond.first; // If input position is empty, then insert all the input positions, because the input numbers of this op are variable. if (inputs_format_position.size() == 0) { - for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); input_index++) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + for (size_t input_index = 0; input_index < input_num; input_index++) { inputs_format_position.push_back(input_index); } } @@ -370,13 +373,15 @@ void SetKernelInfo(const CNodePtr &kernel_node, KernelType kernel_type) { } std::vector inputs_format; std::vector inputs_type; - for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + for (size_t input_index = 0; input_index < input_num; ++input_index) { inputs_format.emplace_back(kOpFormat_DEFAULT); inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index)); } std::vector outputs_format; std::vector outputs_type; - for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) { + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + for (size_t output_index = 0; output_index < output_num; ++output_index) { outputs_format.emplace_back(kOpFormat_DEFAULT); outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index)); } diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.cc b/mindspore/ccsrc/runtime/device/kernel_runtime.cc index 26f3dfe26b..d64967cea7 100644 --- a/mindspore/ccsrc/runtime/device/kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/kernel_runtime.cc @@ -115,7 +115,8 @@ void KernelRuntime::RunOpClearMemory(const session::KernelGraph *graph) { for (const auto &cnode : graph->execution_order()) { MS_EXCEPTION_IF_NULL(cnode); // clear output memory resource - for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(cnode); ++index) { + size_t output_num = AnfAlgo::GetOutputTensorNum(cnode); + for (size_t index = 0; index < output_num; ++index) { AnfAlgo::SetOutputAddr(nullptr, index, cnode.get()); } // clear workspace memory resource @@ -524,7 +525,8 @@ void KernelRuntime::AssignCommunicationNodeInputMem(MemType type, const AnfNodeP MS_EXCEPTION_IF_NULL(mem_manager_); size_t total_size = 0; std::vector> addr_size; - for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(node); ++i) { + size_t input_num = AnfAlgo::GetInputTensorNum(node); + for (size_t i = 0; i < input_num; ++i) { auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(node, i); auto input_node = input_node_with_index.first; DeviceAddressPtr address = nullptr; @@ -825,7 +827,8 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); auto visit_nop_node = (ms_context->get_param(MS_CTX_EXECUTION_MODE) != kPynativeMode); - for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel); + for (size_t i = 0; i < input_num; ++i) { auto op_name = 
AnfAlgo::GetCNodeName(cnode); constexpr auto none_placeholder_index = 3; if (op_name == kDynamicRNNOpName && i == none_placeholder_index) { @@ -993,7 +996,8 @@ void KernelRuntime::ClearOutputAddress(const std::vector &inputs, if (parameter->used_graph_count() != 0) { continue; } - for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(input_node); ++index) { + size_t output_num = AnfAlgo::GetOutputTensorNum(input_node); + for (size_t index = 0; index < output_num; ++index) { if (!AnfAlgo::OutputAddrExist(input_node, index)) { continue; } @@ -1009,7 +1013,8 @@ void KernelRuntime::ClearOutputAddress(const std::vector &inputs, } // clear cnode output address. for (const auto &cnode : execution_order) { - for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(cnode); ++index) { + size_t output_num = AnfAlgo::GetOutputTensorNum(cnode); + for (size_t index = 0; index < output_num; ++index) { if (!AnfAlgo::OutputAddrExist(cnode, index)) { continue; } diff --git a/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc b/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc index cadc5e2f7c..35c10fe2a2 100644 --- a/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc +++ b/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc @@ -353,7 +353,11 @@ std::string IrExportBuilder::GetOpTypeName(const AnfNodePtr &node) { void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePtr &shape, mind_ir::AttributeProto *const attr_proto, std::string *const seq_string) { - if (type->isa() && seq_string != nullptr) { + if (seq_string == nullptr) { + MS_LOG(EXCEPTION) << "seq_string is nullptr."; + } + + if (type->isa()) { *seq_string += "Tuple["; auto elements = type->cast()->elements(); auto tuple_shape = shape->cast()->shape(); @@ -361,13 +365,13 @@ void IrExportBuilder::SetShapeToNodeProto(const TypePtr &type, const BaseShapePt SetShapeToNodeProto(elements[i], tuple_shape[i], attr_proto, seq_string); } *seq_string += "],"; - } else if (type->isa() && shape->isa() && seq_string != nullptr) { + } else if (type->isa() && shape->isa()) { string shape_name = "shape" + std::to_string(GetTupleIndex()); *seq_string += shape_name + ","; mind_ir::TensorProto *tensor_proto = attr_proto->add_tensors(); tensor_proto->set_name(shape_name); SetTensorProto(type, shape, tensor_proto); - } else if ((type->isa() || type->isa()) && seq_string != nullptr) { + } else if (type->isa() || type->isa() || type->isa() || type->isa()) { *seq_string += type->type_name() + ","; } else { MS_LOG(EXCEPTION) << "Type of cnode need to be supported: " << type->type_name(); @@ -553,6 +557,14 @@ void IrExportBuilder::SetValueToAttributeProto(const ValuePtr &value, mind_ir::A } else if (value->isa()) { attr_proto->set_ref_attr_name("none"); MS_LOG(DEBUG) << "Attr string: " << value->type_name(); + } else if (value->isa()) { + if (value->isa()) { + attr_proto->set_ref_attr_name("Monad:UMonad"); + } else if (value->isa()) { + attr_proto->set_ref_attr_name("Monad:IOMonad"); + } else { + MS_LOG(EXCEPTION) << "Unsupported Monad type: " << value->type_name(); + } } else { MS_LOG(EXCEPTION) << "Unsupported type: " << value->type_name(); } diff --git a/mindspore/ccsrc/transform/express_ir/onnx_exporter.cc b/mindspore/ccsrc/transform/express_ir/onnx_exporter.cc index 69e41dd840..f6d79a8e8d 100644 --- a/mindspore/ccsrc/transform/express_ir/onnx_exporter.cc +++ b/mindspore/ccsrc/transform/express_ir/onnx_exporter.cc @@ -45,6 +45,15 @@ struct OpMergedInfo { using GenAttrFuncType = std::function; +static AnfNodePtr 
GetRealInput(const AnfNodePtr &origin_input) { + AnfNodePtr input = origin_input; + while (IsPrimitiveCNode(input, prim::kPrimDepend) || IsPrimitiveCNode(input, prim::kPrimLoad)) { + // Skip Depend and Load cnodes. + input = input->cast()->inputs().at(1); + } + return input; +} + template void SetAttrValueToProto(const ValuePtr &value, onnx::AttributeProto_AttributeType attr_type, onnx::AttributeProto *const attr_proto, const PrimitivePtr &) { @@ -251,7 +260,7 @@ OPERATOR_ONNX_CONVERT_DEFINE( .Attr("strides", "strides", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<2>)) OPERATOR_ONNX_CONVERT_DEFINE(Gather, Gather, OpNameInfo()) -OPERATOR_ONNX_CONVERT_DEFINE(make_tuple, SequenceConstruct, OpNameInfo()) +OPERATOR_ONNX_CONVERT_DEFINE(MakeTuple, SequenceConstruct, OpNameInfo()) OPERATOR_ONNX_CONVERT_DEFINE(Concat, Concat, OpNameInfo()) OPERATOR_ONNX_CONVERT_DEFINE(RealDiv, Div, OpNameInfo()) OPERATOR_ONNX_CONVERT_DEFINE(ReduceSum, ReduceSum, OpNameInfo()) @@ -278,7 +287,7 @@ void RegisterOpConverters(const std::function &fn) { fn(OP_CONVERT_FUNCTION_NAME(BatchNorm)()); fn(OP_CONVERT_FUNCTION_NAME(MatMul)()); - fn(OP_CONVERT_FUNCTION_NAME(make_tuple)()); + fn(OP_CONVERT_FUNCTION_NAME(MakeTuple)()); fn(OP_CONVERT_FUNCTION_NAME(Concat)()); fn(OP_CONVERT_FUNCTION_NAME(RealDiv)()); fn(OP_CONVERT_FUNCTION_NAME(BiasAdd)()); @@ -529,7 +538,12 @@ void OnnxExporter::MatchAndMark(const FuncGraphPtr &func_graph, const std::vecto // if the key `input` does not exist, just create a new one op_merged_infos[cnode].referred_count += 1; } - for (auto &input : cnode->inputs()) { + for (auto &orig_input : cnode->inputs()) { + if (HasAbstractMonad(orig_input)) { + // Skip monad inputs. + continue; + } + auto input = GetRealInput(orig_input); if (!input->isa()) { continue; } @@ -987,7 +1001,9 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &n std::vector op_inputs; // first process node input 1,2,..., since when node input is a ValueNode, here need to create a Constant Operator for (size_t i = 1; i < inputs.size(); i++) { - op_inputs.push_back(inputs[i]); + if (!HasAbstractMonad(inputs[i])) { + op_inputs.push_back(inputs[i]); + } } auto op_value = dyn_cast(op); if (op_value == nullptr) { @@ -998,7 +1014,9 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &n MS_LOG(EXCEPTION) << "Need to support node op type " << op_value->value()->type_name(); } - (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim, op_inputs, graph_proto); + if (!IsPrimitiveEquals(prim, prim::kPrimMakeTuple)) { + (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim, op_inputs, graph_proto); + } } size_t OnnxExporter::ExportPrimitive(const FuncGraphPtr & /*func_graph*/, std::map *node_map_ptr, @@ -1103,12 +1121,13 @@ void OnnxExporter::ExportOutput(const FuncGraphPtr & /*func_graph*/, const CNode SetValueInfoType(arg, output_proto, false); } -std::string OnnxExporter::GetNodeInputName(const AnfNodePtr &node, std::map *node_map_ptr, +std::string OnnxExporter::GetNodeInputName(const AnfNodePtr &orig_node, std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { + auto node = GetRealInput(orig_node); if (node->isa()) { auto iter = node_map_ptr->find(node); if (iter == node_map_ptr->end()) { - MS_LOG(EXCEPTION) << "Can not find node '" << node->ToString() << "' in node_map"; + MS_LOG(EXCEPTION) << "Can not find node '" << node->DebugString() << "' in node_map"; } return std::to_string(iter->second); } diff --git 
a/mindspore/ccsrc/transform/graph_ir/convert.cc b/mindspore/ccsrc/transform/graph_ir/convert.cc index a41764cbd8..b64ef4a89d 100644 --- a/mindspore/ccsrc/transform/graph_ir/convert.cc +++ b/mindspore/ccsrc/transform/graph_ir/convert.cc @@ -221,7 +221,7 @@ void DfGraphConvertor::SetupParamInitSubGraph(const TensorOrderMap &tensors, std if (IsValueNode(it)) { auto symbolic = GetValueNode(it); auto name = std::static_pointer_cast(symbolic->node())->name(); - auto iter = vars_.find(name); // get correspoding varaible op + auto iter = vars_.find(name); // get corresponding variable op if (iter != vars_.end()) { op_cache_[it.get()] = iter->second; // #ifdef DRAW_GE_GRAPH @@ -232,7 +232,7 @@ void DfGraphConvertor::SetupParamInitSubGraph(const TensorOrderMap &tensors, std } else if (IsValueNode(it)) { auto refkey = GetValueNode(it); auto name = refkey->tag(); - auto iter = vars_.find(name); // get correspoding varaible op + auto iter = vars_.find(name); // get corresponding variable op if (iter != vars_.end()) { op_cache_[it.get()] = iter->second; compute_sout_ << op_draw_name_[params_[name].get()] << " -> " << op_draw_name_[it.get()] @@ -324,7 +324,7 @@ void DfGraphConvertor::InitParamWithData(const TensorOrderMap &tensors) { auto const_op_desc = TransformUtil::GetGeTensorDesc(it.second->shape_c(), it.second->data_type(), kOpFormat_NCHW); if (const_op_desc == nullptr) { - MS_LOG(ERROR) << "Create variable " << name << " ouptut descriptor failed!"; + MS_LOG(ERROR) << "Create variable " << name << " output descriptor failed!"; continue; } (void)std::static_pointer_cast(const_op)->update_output_desc_y(*const_op_desc); @@ -337,7 +337,7 @@ void DfGraphConvertor::InitParamWithData(const TensorOrderMap &tensors) { // create tensor descriptor for output descriptor auto desc = TransformUtil::GetGeTensorDesc(it.second->shape_c(), it.second->data_type(), kOpFormat_NCHW); if (desc == nullptr) { - MS_LOG(ERROR) << "Create variable " << name << " ouptut descriptor failed!"; + MS_LOG(ERROR) << "Create variable " << name << " output descriptor failed!"; continue; } @@ -484,7 +484,7 @@ DfGraphConvertor &DfGraphConvertor::GenerateBroadcastGraph(const TensorOrderMap // create tensor descriptor for output descriptor auto desc = TransformUtil::GetGeTensorDesc(shape_ge, tensor->data_type(), kOpFormat_NCHW); if (desc == nullptr) { - MS_LOG(ERROR) << "Create variable " << name << " ouptut descriptor failed!"; + MS_LOG(ERROR) << "Create variable " << name << " output descriptor failed!"; continue; } @@ -540,8 +540,10 @@ DfGraphConvertor &DfGraphConvertor::ConvertAllNode() { compute_sout_ << "digraph {" << endl; init_sout_.clear(); init_sout_ << "digraph {" << endl; +#if (defined ENABLE_GE) checkpoint_sout_.clear(); checkpoint_sout_ << "digraph {" << endl; +#endif restore_checkpoint_sout_.clear(); restore_checkpoint_sout_ << "digraph {" << endl; @@ -627,7 +629,7 @@ void DfGraphConvertor::TraceOutput(const AnfNodePtr node) { for (unsigned int i = 1; i < c->inputs().size(); i++) { TraceOutput(c->input(i)); } - } else if (name == "Depend") { + } else if (name == prim::kPrimDepend->name()) { if (c->inputs().size() < 3) { // "Depend" primitive have 3 inputs MS_LOG(EXCEPTION) << "length of inputs is " << c->inputs().size() << ", which is less than 3"; } @@ -644,7 +646,7 @@ void DfGraphConvertor::TraceOutput(const AnfNodePtr node) { if (item != out_handle_cache_.end()) { index = item->second.out; } else { - MS_LOG(WARNING) << "Can't get operater: " << anf_out->fullname_with_scope() << " 's output item"; + MS_LOG(WARNING) << 
"Can't get operator: " << anf_out->fullname_with_scope() << " 's output item"; } } MS_LOG(INFO) << "Add graph output: " << anf_out->fullname_with_scope() << ":" << index; @@ -744,7 +746,6 @@ void DfGraphConvertor::GetCaseNodeInput(const CNodePtr node, const CNodePtr inpu for (size_t i = 1; i < node->inputs().size(); i++) { case_inputs.emplace_back(node->input(i)); } - std::shared_ptr> branches = std::make_shared>(); auto bnode = input_node->input(2)->cast(); for (size_t i = 1; i < bnode->inputs().size(); i++) { @@ -768,12 +769,12 @@ void DfGraphConvertor::GetCaseNodeInput(const CNodePtr node, const CNodePtr inpu auto item = case_inputs[i]; auto op = Convert(item); if (op != nullptr) { - tuple_items->emplace_back(OutHandler(op, "")); + tuple_items->emplace_back(OutHandler(op, "", item)); } else if (out_handle_cache_.find(item.get()) != out_handle_cache_.end()) { tuple_items->push_back(out_handle_cache_[item.get()]); } else { - MS_LOG(WARNING) << "This anf node is not supported as a case input: " << item->ToString(); - continue; + MS_LOG(DEBUG) << "Add an empty out handler: " << item->ToString(); + tuple_items->push_back(OutHandler()); } } @@ -785,6 +786,23 @@ void DfGraphConvertor::GetCaseNodeInput(const CNodePtr node, const CNodePtr inpu case_input_handle_cache_[node.get()] = case_input_items; } +void DfGraphConvertor::UpdateTupleOutCache() { + for (auto &it : tuple_out_handle_cache_) { + std::size_t len = it.second->size(); + for (std::size_t i = 0; i < len; i++) { + OutHandler handle = (*it.second)[i]; + if (handle.op == nullptr) { + continue; + } + string name = handle.op->GetName(); + if (vars_.count(name) && (vars_[name] != nullptr)) { + (*it.second)[i] = OutHandler(vars_[name], handle.out, handle.node); + MS_LOG(INFO) << "update tuple_out_handle_cache_ " << name; + } + } + } +} + DfGraphConvertor &DfGraphConvertor::BuildGraph() { SetupDatasetIterGetNextNode(dataset_iter_getnext_); @@ -803,25 +821,10 @@ DfGraphConvertor &DfGraphConvertor::BuildGraph() { } // update tuple_out_handle_cache_ - for (auto it : tuple_out_handle_cache_) { - std::size_t len = it.second->size(); - for (std::size_t i = 0; i < len; i++) { - OutHandler handle = (*it.second)[i]; - if (handle.op) { - string name = handle.op->GetName(); - if (vars_.count(name)) { - OperatorPtr new_op = vars_[name]; - if (new_op != nullptr) { - MS_LOG(INFO) << "update tuple_out_handle_cache_ " << name; - (*it.second)[i] = OutHandler(new_op, handle.out); - } - } - } - } - } + UpdateTupleOutCache(); - // set up dependices - MS_LOG(DEBUG) << "set up dependices"; + // set up dependencies + MS_LOG(DEBUG) << "set up dependencies"; nodes = ::mindspore::TopoSort(anf_graph_->get_return()); for (auto &it : nodes) { SetNodeInput(it); @@ -947,7 +950,195 @@ DfGraphPtr DfGraphConvertor::GetSaveCheckpointGraph() { return save_ckp_graph_; DfGraphPtr DfGraphConvertor::GetBroadcastGraph() { return broadcast_graph_; } -void DfGraphConvertor::SetOpControlInput(const AnfNodePtr node) { +bool DfGraphConvertor::IsSourceEdgeNode(const AnfNodePtr &node) { + if (!node->isa()) { + return false; + } + auto cnode = node->cast(); + if (!IsCustomCNode(cnode)) { + std::string name = GetCNodeTargetFuncName(cnode); + if (name.empty()) { + return false; + } + + // Ignore apply node Depend, UpdateState, make_tuple. make_tuple in ge pipeline. 
+ if ((name == prim::kPrimDepend->name()) || (name == prim::kPrimUpdateState->name()) || + (name == prim::kPrimReturn->name()) || (name == prim::kPrimMakeTuple->name())) { + return false; + } + } + // Load, and other normal primitives that take a monad input. + auto has_monad = std::any_of(cnode->inputs().begin(), cnode->inputs().end(), + [](const AnfNodePtr &node) -> bool { return HasAbstractMonad(node); }); + if (has_monad) { + return true; + } + + // A primitive with a make_tuple as input. + for (auto &input : cnode->inputs()) { + if (IsPrimitiveCNode(input, prim::kPrimMakeTuple)) { + auto tuple = input->cast(); + auto ret = std::any_of(tuple->inputs().begin(), tuple->inputs().end(), + [](const AnfNodePtr &node) -> bool { return HasAbstractMonad(node); }); + if (ret) { + return true; + } + } + } + + return false; +} + +bool DfGraphConvertor::IsControlEdgeNode(const AnfNodePtr &node) { + if (!node->isa()) { + return false; + } + auto cnode = node->cast(); + if (!IsCustomCNode(cnode)) { + std::string name = GetCNodeTargetFuncName(cnode); + if (name.empty()) { + return false; + } + + // Ignore Load, Depend, UpdateState, make_tuple and Return apply nodes. + if ((name == prim::kPrimLoad->name()) || (name == prim::kPrimDepend->name()) || + (name == prim::kPrimUpdateState->name()) || (name == prim::kPrimMakeTuple->name()) || + (name == prim::kPrimReturn->name())) { + return false; + } + } + return true; +} + +OperatorPtr DfGraphConvertor::ToOperatorPtr(const AnfNodePtr &node) { + auto op = Convert(GetRealOpNode(node)); + if (op == nullptr) { + MS_LOG(ERROR) << "Convert control depend node to operator failed, " << node->ToString(); + error_ = FAILED; + return nullptr; + } + return op; +} + +void DfGraphConvertor::AddEdgeToCache(const AnfNodePtr &src, const AnfNodePtr &dest) { + auto item = monad_control_edge_cache_.find(src); + if (item == monad_control_edge_cache_.end()) { + monad_control_edge_cache_[src] = std::set{dest}; + } else { + item->second.insert(dest); + } +} + +void DfGraphConvertor::AddEdgeForLoad(const AnfNodePtr &node) { + auto func_graph = node->func_graph(); + MS_EXCEPTION_IF_NULL(func_graph); + auto mng = func_graph->manager(); + if (mng == nullptr) { + mng = Manage(func_graph, true); + func_graph->set_manager(mng); + } + auto manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + + auto &users = manager->node_users()[node]; + std::shared_ptr> src_node_list = std::make_shared>(); + std::shared_ptr> dst_node_list = std::make_shared>(); + for (const auto &iter : users) { + auto user_node = iter.first; + auto name = GetCNodeTargetFuncName(user_node->cast()); + if (name == prim::kPrimUpdateState->name()) { + FindDestOps(user_node, dst_node_list, false); + continue; + } + if (IsControlEdgeNode(user_node)) { + src_node_list->push_back(user_node); + continue; + } + FindDestOps(user_node, src_node_list, false); + } + + // add to cache + for (auto &dest : *dst_node_list) { + for (auto &src : *src_node_list) { + AddEdgeToCache(src, dest); + } + } +} + +void DfGraphConvertor::FindDestOps(const AnfNodePtr &node, const std::shared_ptr> &node_list, + bool top) { + auto func_graph = node->func_graph(); + MS_EXCEPTION_IF_NULL(func_graph); + auto mng = func_graph->manager(); + if (mng == nullptr) { + mng = Manage(func_graph, true); + func_graph->set_manager(mng); + } + auto manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + + auto users = manager->node_users()[node]; + for (const auto &iter : users) { + auto user_node = iter.first; + if (IsControlEdgeNode(user_node)) { 
if (!top) { + node_list->push_back(user_node); + } + } else { + FindDestOps(user_node, node_list, false); + } + } +} + +void DfGraphConvertor::AutoMonadCollectInput(const AnfNodePtr &node) { + if (!IsSourceEdgeNode(node)) { + return; + } + + // Add control edges if the node has monad inputs. + std::string name = GetCNodeTargetFuncName(node->cast()); + if (name == prim::kPrimLoad->name()) { + AddEdgeForLoad(node); + } else { + auto src_ops = ToOperatorPtr(node); + if (src_ops != nullptr) { + // Find dest ops list + std::shared_ptr> dst_node_list = std::make_shared>(); + FindDestOps(node, dst_node_list, true); + for (auto &dest : *dst_node_list) { + AddEdgeToCache(node, dest); + } + } + } +} + +void DfGraphConvertor::AutoMonadSetInput(const AnfNodePtr &node) { + if (monad_control_edge_cache_.find(node) == monad_control_edge_cache_.end()) { + return; + } + + auto src_ops = ToOperatorPtr(node); + if (src_ops != nullptr) { + for (auto &dest : monad_control_edge_cache_[node]) { + auto dest_ops = ToOperatorPtr(dest); + if (dest_ops == nullptr) { + continue; + } + (void)dest_ops->AddControlInput(*src_ops); +#ifdef DRAW_GE_GRAPH + compute_sout_ << op_draw_name_[node.get()] << " -> " << op_draw_name_[dest.get()] << "[style=\"dotted\"]" << endl; +#endif + } + } +} + +void DfGraphConvertor::AutoMonadSetControlInput(const AnfNodePtr &node) { + AutoMonadCollectInput(node); + AutoMonadSetInput(node); +} + +void DfGraphConvertor::SetOpControlInput(const AnfNodePtr &node) { + AutoMonadSetControlInput(node); if (control_depend_cache_.find(node.get()) == control_depend_cache_.end()) { return; } @@ -965,6 +1156,46 @@ void DfGraphConvertor::SetOpControlInput(const AnfNodePtr node) { const std::vector trans_var_list = {string(kNameAssign), string(kNameAssignAdd), string(kNameAssignSub)}; +AnfNodePtr DfGraphConvertor::ParseLoadInput(const CNodePtr &cnode) { + if (cnode->inputs().size() < 3) { + MS_LOG(EXCEPTION) << "input size error, " << cnode->ToString(); + } + const size_t para_index = 1; + return cnode->input(para_index); +} + +void DfGraphConvertor::SetTupleOpInput(const OpAdapterPtr &adpt, const CNodePtr &node, const AnfNodePtr &pred, + const OperatorPtr &src, int index) { + std::shared_ptr> handler_vec = tuple_out_handle_cache_[pred.get()]; + std::shared_ptr> handler_vec_without_monad = std::make_shared>(); + bool with_monad = false; + for (auto &handler : *handler_vec) { + // When a tuple element is a monad, its handler operator is nullptr and should be ignored.
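+ // Only the remaining non-monad handlers are collected and passed on to the op adapter below.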
+ if (handler.op == nullptr) { + if ((handler.node != nullptr) && !HasAbstractMonad(handler.node)) { + MS_LOG(WARNING) << "Unsupported node in tuple : " << node->ToString(); + } + continue; + } + with_monad = true; + handler_vec_without_monad->push_back(handler); + } + int ret = adpt->setInput(src, index, handler_vec_without_monad); + + if ((ret == 0) && pred->isa() && (pred->cast()->inputs().size() == handler_vec->size() + 1)) { + for (unsigned int j = 0; j < handler_vec_without_monad->size(); j++) { + AnfNodePtr input_node = pred->cast()->input(j + 1); + if (with_monad) { + input_node = handler_vec_without_monad->at(j).node; + } + compute_sout_ << op_draw_name_[input_node.get()] << " -> " << op_draw_name_[node.get()] << ":" << index << endl; + AddGraphConstInput(handler_vec_without_monad->at(j).op); + } + return; + } + MS_LOG(WARNING) << "This anf node is not supported as a tuple item : " << node->ToString(); +} + void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node) { OperatorPtr src = Convert(node); int case_flag = 0; @@ -983,13 +1214,24 @@ void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node pred = inputs[i]; } - while (pred->isa() && GetCNodeTargetFuncName(pred->cast()) == "Depend") { + while (pred->isa() && GetCNodeTargetFuncName(pred->cast()) == prim::kPrimDepend->name()) { pred = pred->cast()->input(1); } - // skip the None input - if (IsValueNode(pred)) { + + // Skip UMonad and IOMonad value inputs. + if (IsValueNode(pred) || IsValueNode(pred)) { continue; } + + // Skip None and UpdateState inputs; a Load input is unwrapped to its parameter just below. + if (IsValueNode(pred) || IsPrimitiveCNode(pred, prim::kPrimUpdateState)) { + continue; + } + + if (IsPrimitiveCNode(pred, prim::kPrimLoad)) { + pred = ParseLoadInput(pred->cast()); + } + // transform "Const" op to "Variable" op when the next node is "Assign" op. 
std::string c_name = GetCNodeTargetFuncName(node); auto pos = std::find(trans_var_list.begin(), trans_var_list.end(), c_name); @@ -1010,10 +1252,11 @@ void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node vars_[name] = variable; } } + int index = SizeToInt(i); // find in out_hadnle_cache_ first auto it = out_handle_cache_.find(pred.get()); if (it != out_handle_cache_.end()) { - int ret = adpt->setInput(src, SizeToInt(i), it->second); + int ret = adpt->setInput(src, index, it->second); if (ret == 0) { if (pred->isa() && GetCNodeTargetFuncName(pred->cast()) == prim::kTupleGetItem) { compute_sout_ << op_draw_name_[pred->cast()->input(1).get()] << " -> " << op_draw_name_[node.get()] @@ -1027,20 +1270,10 @@ void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node AddGraphConstInput(it->second.op); } } else if (tuple_out_handle_cache_.find(pred.get()) != tuple_out_handle_cache_.end()) { - std::shared_ptr> handler_vec = tuple_out_handle_cache_[pred.get()]; - int ret = adpt->setInput(src, SizeToInt(i), handler_vec); - if ((ret == 0) && pred->isa() && (pred->cast()->inputs().size() == handler_vec->size() + 1)) { - for (unsigned int j = 0; j < handler_vec->size(); j++) { - compute_sout_ << op_draw_name_[pred->cast()->input(j + 1).get()] << " -> " - << op_draw_name_[node.get()] << ":" << i << endl; - AddGraphConstInput(handler_vec->at(j).op); - } - } else { - MS_LOG(WARNING) << "Convert tuple node setInput failed : " << node->ToString(); - } + SetTupleOpInput(adpt, node, pred, src, index); } else { auto op = Convert(pred); - int ret = adpt->setInput(src, SizeToInt(i), op); + int ret = adpt->setInput(src, index, op); if (ret == 0) { compute_sout_ << op_draw_name_[pred.get()] << " -> " << op_draw_name_[node.get()] << ":" << i << endl; AddGraphConstInput(op); @@ -1079,15 +1312,15 @@ void DfGraphConvertor::ProcessSubgraph(AnfNodePtr node, const std::vectorcast()->input(1)->cast(); FuncGraphPtr anf_graph = graph_node->value()->cast(); - DfGraphConvertor convertor(anf_graph); - convertor.use_inputs_ = true; - convertor.inputs_ = inputs; - (void)convertor.ConvertAllNode().BuildGraph(); + DfGraphConvertor converter(anf_graph); + converter.use_inputs_ = true; + converter.inputs_ = inputs; + (void)converter.ConvertAllNode().BuildGraph(); std::string name = graph_node->ToString() + "_ge_graph.dot"; if (MsContext::GetInstance()->get_param(MS_CTX_SAVE_GRAPHS_FLAG)) { - convertor.DrawComputeGraph(name); + converter.DrawComputeGraph(name); } - branches_map_[node.get()] = *(convertor.df_graph_); + branches_map_[node.get()] = *(converter.df_graph_); } // Update GE op's shape and type info @@ -1123,8 +1356,9 @@ OperatorPtr DfGraphConvertor::Convert(const AnfNodePtr node) { return op_cache_[node.get()]; } - // do not convert primitive node - if (IsValueNode(node)) { + // do not convert primitive node, Load, UpdateState + if (IsValueNode(node) || IsPrimitiveCNode(node, prim::kPrimLoad) || + IsPrimitiveCNode(node, prim::kPrimUpdateState)) { return nullptr; } @@ -1136,10 +1370,13 @@ OperatorPtr DfGraphConvertor::Convert(const AnfNodePtr node) { return ConvertParameter(node); } if (node->isa()) { + if (IsValueNode(node)) { + return nullptr; + } return ConvertValueNode(node->cast()); } - MS_LOG(ERROR) << "Invalide AnfNode"; + MS_LOG(ERROR) << "Invalid AnfNode"; error_ = INVALID_ARGUMENT; return nullptr; } @@ -1149,14 +1386,16 @@ void DfGraphConvertor::ConvertMakeTuple(const CNodePtr node) { // convert each tuple item to a OutHandler for (size_t i = 1; i < 
node->inputs().size(); i++) { AnfNodePtr item = node->input(i); + if (IsPrimitiveCNode(item, prim::kPrimLoad)) { + item = ParseLoadInput(item->cast()); + } OperatorPtr op = Convert(item); if (op != nullptr) { - tuple_items->emplace_back(OutHandler(op, "")); + tuple_items->emplace_back(OutHandler(op, "", item)); } else if (out_handle_cache_.find(item.get()) != out_handle_cache_.end()) { tuple_items->push_back(out_handle_cache_[item.get()]); } else { - MS_LOG(WARNING) << "This anf node is not supported as a tuple item : " << item->ToString(); - return; + tuple_items->push_back(OutHandler(nullptr, "", item)); } } @@ -1520,31 +1759,25 @@ void DfGraphConvertor::ConvertControlDependNode(const CNodePtr node) { bool DfGraphConvertor::CheckCNode(const std::string &name, const CNodePtr node) { // ignore apply node of return - if (name == "return" || name == "Depend") { - return false; - } - - if (name == "" && GetCNodeFuncName(node) == "switch_layer") { - return false; - } - - if (name == "Partial") { + if (name == "" || name == prim::kPrimReturn->name() || name == prim::kPrimDepend->name() || + name == prim::kPrimSwitchLayer->name() || name == prim::kPrimPartial->name()) { return false; } // make_tuple is used for a dynamic_input, convert it to a vector of OutHandlers - if (name == "make_tuple") { + if (name == prim::kPrimMakeTuple->name()) { ConvertMakeTuple(node); return false; } // As for nodes with multi outputs, convert tuple_getitem to OutHandle - if (name == prim::kTupleGetItem) { + if (name == prim::kPrimTupleGetItem->name()) { ConvertTupleGetItem(node); return false; } - if (name == "ControlDepend") { + // ControlDepend + if (name == prim::kPrimControlDepend->name()) { ConvertControlDependNode(node); return false; } diff --git a/mindspore/ccsrc/transform/graph_ir/convert.h b/mindspore/ccsrc/transform/graph_ir/convert.h index ec46edd15a..c40efd876a 100644 --- a/mindspore/ccsrc/transform/graph_ir/convert.h +++ b/mindspore/ccsrc/transform/graph_ir/convert.h @@ -21,6 +21,7 @@ #include #include +#include #include #include #include @@ -163,7 +164,7 @@ class DfGraphConvertor { void TraceOutputFromParameter(const AnfNodePtr &anf_out); void TraceOutputFromTupleGetItem(const AnfNodePtr &anf_out); void SetNodeInput(AnfNodePtr node); - void SetOpControlInput(const AnfNodePtr node); + void SetOpControlInput(const AnfNodePtr &node); void UpdateOpDesc(AnfNodePtr node); void SetSubgraph(AnfNodePtr node); void ProcessSubgraph(AnfNodePtr node, const std::vector &inputs); @@ -171,6 +172,19 @@ class DfGraphConvertor { void DrawCNode(const CNodePtr node, const OpAdapterPtr adpt); void UpdateDataOpDesc(const AnfNodePtr &it, const OperatorPtr &op) const; void AddGraphConstInput(const OperatorPtr &op); + OperatorPtr ToOperatorPtr(const AnfNodePtr &node); + bool IsSourceEdgeNode(const AnfNodePtr &node); + bool IsControlEdgeNode(const AnfNodePtr &node); + void AddEdgeForLoad(const AnfNodePtr &node); + void AddEdgeToCache(const AnfNodePtr &src, const AnfNodePtr &dest); + void FindDestOps(const AnfNodePtr &node, const std::shared_ptr> &node_list, bool top); + AnfNodePtr ParseLoadInput(const CNodePtr &cnode); + void AutoMonadSetControlInput(const AnfNodePtr &node); + void AutoMonadCollectInput(const AnfNodePtr &node); + void AutoMonadSetInput(const AnfNodePtr &node); + void SetTupleOpInput(const OpAdapterPtr &adpt, const CNodePtr &node, const AnfNodePtr &pred, const OperatorPtr &src, + int index); + void UpdateTupleOutCache(void); std::shared_ptr anf_graph_{nullptr}; std::shared_ptr df_graph_{nullptr}; @@ -181,6 
+195,7 @@ class DfGraphConvertor { std::unordered_map branches_map_; std::unordered_map op_cache_; std::unordered_map> control_depend_cache_; + std::unordered_map> monad_control_edge_cache_; /* record "tuple_getitem"<->"out_handler" mapping */ std::unordered_map out_handle_cache_; /* record "make_tuple"<->"out_handler vector" mapping */ diff --git a/mindspore/ccsrc/transform/graph_ir/op_adapter.cc b/mindspore/ccsrc/transform/graph_ir/op_adapter.cc index 8c691f314d..f9794363b2 100644 --- a/mindspore/ccsrc/transform/graph_ir/op_adapter.cc +++ b/mindspore/ccsrc/transform/graph_ir/op_adapter.cc @@ -140,8 +140,7 @@ Status OpAdapterImpl::SetCustomOpInput(const CusOperatorPtr &op, int index, cons Status OpAdapterImpl::SetNormalOpInput(const OperatorPtr &op, int index, const OperatorPtr &input) { MS_EXCEPTION_IF_NULL(op); auto it = input_map_.find(index); - if (it != input_map_.end()) { - MS_EXCEPTION_IF_NULL(input); + if (input != nullptr && it != input_map_.end()) { MS_LOG(DEBUG) << "Link op " << input->GetName() << " to " << op->GetName() << ":" << it->second.name; it->second.set_op(op, input); return SUCCESS; diff --git a/mindspore/ccsrc/transform/graph_ir/op_adapter_base.h b/mindspore/ccsrc/transform/graph_ir/op_adapter_base.h index ae63a45960..947c81941f 100644 --- a/mindspore/ccsrc/transform/graph_ir/op_adapter_base.h +++ b/mindspore/ccsrc/transform/graph_ir/op_adapter_base.h @@ -68,8 +68,10 @@ using CustomOperator = ge::CustomOperator; struct OutHandler { OperatorPtr op; std::string out; - OutHandler() : op(nullptr), out("") {} - OutHandler(const OperatorPtr &op, const std::string out) : op(op), out(out) {} + AnfNodePtr node; + OutHandler() : op(nullptr), out(""), node(nullptr) {} + OutHandler(const OperatorPtr &op, const std::string out, const AnfNodePtr &node = nullptr) + : op(op), out(out), node(node) {} }; struct ControlEdge { diff --git a/mindspore/ccsrc/utils/convert_utils.cc b/mindspore/ccsrc/utils/convert_utils.cc index 6c885d2afa..d1d5eacf84 100644 --- a/mindspore/ccsrc/utils/convert_utils.cc +++ b/mindspore/ccsrc/utils/convert_utils.cc @@ -179,8 +179,8 @@ bool SameSubgraph(const AnfNodePtr &root1, const AnfNodePtr &root2, FuncGraphPai AnfNodePtr node2 = todo.top().second; bool condition = false; - std::vector s1 = SuccIncoming(node1); - std::vector s2 = SuccIncoming(node2); + const auto &s1 = GetInputs(node1); + const auto &s2 = GetInputs(node2); if (s1.size() != s2.size()) { return false; diff --git a/mindspore/ccsrc/utils/convert_utils_py.cc b/mindspore/ccsrc/utils/convert_utils_py.cc index 0dadae5cae..e3816d88be 100644 --- a/mindspore/ccsrc/utils/convert_utils_py.cc +++ b/mindspore/ccsrc/utils/convert_utils_py.cc @@ -159,6 +159,8 @@ py::object ValuePtrToPyData(const ValuePtr &value) { } else if (value->isa()) { // FuncGraph is not used in the backend, return None ret = py::none(); + } else if (value->isa()) { + ret = py::none(); } else { MS_LOG(EXCEPTION) << "Unsupported convert value: " << value->ToString() << " to a PyData."; } @@ -377,6 +379,27 @@ AbstractBasePtr PyList2DynamicShapeTensor(const py::object &shape_obj, const py: return tensor; } +static bool IsMonadType(const py::object &type_obj) { + if (py::isinstance(type_obj)) { + auto type = type_obj.cast(); + return type->isa(); + } + return false; +} + +static AbstractBasePtr ToMonadAbstract(const py::object &type_obj) { + if (py::isinstance(type_obj)) { + auto type = type_obj.cast(); + if (type->isa()) { + return kUMonad->ToAbstract(); + } + if (type->isa()) { + return kIOMonad->ToAbstract(); + } + } + 
MS_LOG(EXCEPTION) << "Not a monad type object: " << py::str(type_obj); +} + AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py::object &type_obj, const py::object &output) { if ((py::isinstance(shape_obj) || py::isinstance(shape_obj)) && py::isinstance(type_obj)) { @@ -413,6 +436,9 @@ AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py // AbstractNone indicates there is no output for this CNode node. auto abstract_none = std::make_shared(); return abstract_none; + } else if (IsMonadType(type_obj)) { + // Return monad abstract if it is monad type. + return ToMonadAbstract(type_obj); } else { // When sparse enabled, the undetermined might be raised and eliminated in opt passes auto context = MsContext::GetInstance(); diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index 4facb834b7..c9d0979bc1 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -422,6 +422,9 @@ constexpr auto kControlDependMode = "depend_mode"; constexpr auto kRealInputIndexInDepend = 1; constexpr auto kDependAttachNodeIndex = 2; constexpr auto kDependInputSize = 3; +// index define of UpdateState +constexpr auto kUpdateStateStateInput = 1; +constexpr auto kUpdateStateRealInput = 2; // format constexpr auto kOpFormat_DEFAULT = "DefaultFormat"; constexpr auto kOpFormat_NC1KHKWHWC0 = "NC1KHKWHWC0"; @@ -483,11 +486,12 @@ const std::set kOptOperatorSet = {kMomentumOpName, kSparseApplyFtrlV2Name, kSGDName, kLARSUpdateName, - kPullOpName, kCombineMomentumWeightOpName, kCombineMomentumOpName, kSparseApplyProximalAdagradOpName}; +const std::set kPosteriorOperatorSet = {kPullOpName}; + const std::set kHWSpecialFormatSet = { kOpFormat_FRACTAL_Z_3D, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0, kOpFormat_FRAC_NZ, kOpFormat_C1HWNCoC0, kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04, kOpFormat_FRACTAL_ZN_LSTM, kOpFormat_NDC1HWC0, kOpFormat_FRAC_Z}; diff --git a/mindspore/ccsrc/vm/transform.cc b/mindspore/ccsrc/vm/transform.cc index e6ad847777..11f090de1d 100644 --- a/mindspore/ccsrc/vm/transform.cc +++ b/mindspore/ccsrc/vm/transform.cc @@ -162,9 +162,6 @@ int64_t CompileGraph::LinConvert(const FuncGraphPtr &graph, const GraphSegmentPt } } AddExternal(result); - for (auto &o : result.outputs) { - Push(o); - } return RET_SUCCESS; } @@ -377,6 +374,9 @@ void CompileGraph::AddExternal(const LinConvertResult &result) { args.emplace_back(Ref(result.inputs[i])); } AddInst(Instruction::kExternal, args); + for (auto &out : result.outputs) { + Push(out); + } } void TraverseGraphMap( diff --git a/mindspore/common/__init__.py b/mindspore/common/__init__.py index 39afd6410b..c60c79ccf5 100644 --- a/mindspore/common/__init__.py +++ b/mindspore/common/__init__.py @@ -14,6 +14,7 @@ # ============================================================================ """Top-level reference to dtype of common module.""" from . import dtype +from . 
import monad from .api import ms_function from .dtype import * from .parameter import Parameter, ParameterTuple @@ -27,5 +28,6 @@ __all__.extend([ 'ms_function', # api 'Parameter', 'ParameterTuple', # parameter "dtype", + 'monad', "set_seed", "get_seed" # random seed ]) diff --git a/mindspore/nn/_graph_kernels/__init__.py b/mindspore/common/monad.py similarity index 68% rename from mindspore/nn/_graph_kernels/__init__.py rename to mindspore/common/monad.py index 356d183766..3ea7728b6a 100644 --- a/mindspore/nn/_graph_kernels/__init__.py +++ b/mindspore/common/monad.py @@ -12,12 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -""" -GraphKernel. +"""Define Monad default value.""" +from .._c_expression import IOMonad, UMonad -GraphKernel provides a unified style to express graph and kernel for user. -It breaks the boundary between graph and kernel and provides more opportunities to do compile optimization. -""" -from .graph_kernels import LambUpdateWithLR, LambNextMV +# Universe monad default value. +U = UMonad() -__all__ = ['LambUpdateWithLR', 'LambNextMV'] +# IO monad default value. +IO = IOMonad() + +__all__ = ['U', 'IO'] diff --git a/mindspore/core/abstract/abstract_function.h b/mindspore/core/abstract/abstract_function.h index ac18fa689f..ae959849a8 100644 --- a/mindspore/core/abstract/abstract_function.h +++ b/mindspore/core/abstract/abstract_function.h @@ -58,6 +58,11 @@ class AbstractFuncUnion : public AbstractFunction { bool operator==(const AbstractFunction &other) const override; std::size_t hash() const override; AbstractFunctionPtr Copy() const override { MS_LOG(EXCEPTION) << "Cannot Copy from AbstractFuncUnion"; } + bool HasIsolateNodesFlag() const override { + bool flag = std::any_of(func_list_.cbegin(), func_list_.cend(), + [](const AbstractFunctionPtr &func) { return func->HasIsolateNodesFlag(); }); + return flag; + } private: AbstractFuncAtomPtrList func_list_; @@ -126,13 +131,15 @@ class FuncGraphAbstractClosure : public AbstractFuncAtom { std::string ToString() const override; + bool HasIsolateNodesFlag() const override { return !func_graph_->isolate_nodes().empty(); } + private: FuncGraphPtr func_graph_; AnalysisContextPtr context_; // To discriminate different usage of same graph by using this tracking_id, // so different tracking_id will produce different FuncGraphAbstractClosure, // different FuncGraphEvaluator. - // Espcecially usefull for recursive func graph call, so it will not mess up + // Especially useful for recursive func graph call, so it will not mess up // the graph_context_ in FuncGraphEvaluator. // Notes: Be careful to use nullptr for this variable. // store it as weak_ptr to break reference cycle. @@ -195,12 +202,16 @@ class PartialAbstractClosure : public AbstractFuncAtom { std::size_t hash() const override; std::string ToString() const override; + bool HasIsolateNodesFlag() const override { return isolate_nodes_flag_; } + void SetIsolateNodesFlag(bool flag) { isolate_nodes_flag_ = flag; } private: AbstractFuncAtomPtr fn_; AbstractBasePtrList args_spec_list_; // The CNode which this PartialAbstractClosure evaluated from. AnfNodeWeakPtr node_; + // True if the bound fn_ has isolate nodes, or an argument evaluated from a function has isolate nodes. 
+ bool isolate_nodes_flag_{false}; }; using PartialAbstractClosurePtr = std::shared_ptr; diff --git a/mindspore/core/abstract/abstract_value.cc b/mindspore/core/abstract/abstract_value.cc index ca9a37b570..9a4ec269a6 100644 --- a/mindspore/core/abstract/abstract_value.cc +++ b/mindspore/core/abstract/abstract_value.cc @@ -1210,5 +1210,42 @@ std::string AbstractSparseTensor::ToString() const { << ", dense_shape: " << dense_shape_->ToString(); return buffer.str(); } + +AbstractBasePtr AbstractUMonad::Join(const AbstractBasePtr &other) { + MS_EXCEPTION_IF_NULL(other); + if (other->isa()) { + return shared_from_base(); + } + MS_EXCEPTION(TypeError) << "Type join failed, type1 = " << GetTypeTrack()->ToString() + << ", type2 = " << other->GetTypeTrack()->ToString(); +} + +bool AbstractUMonad::operator==(const AbstractUMonad &) const { return true; } + +bool AbstractUMonad::operator==(const AbstractBase &other) const { + if (&other == this) { + return true; + } + return other.isa(); +} + +AbstractBasePtr AbstractIOMonad::Join(const AbstractBasePtr &other) { + MS_EXCEPTION_IF_NULL(other); + if (other->isa()) { + return shared_from_base(); + } + MS_EXCEPTION(TypeError) << "Type join failed, type1 = " << GetTypeTrack()->ToString() + << ", type2 = " << other->GetTypeTrack()->ToString(); +} + +bool AbstractIOMonad::operator==(const AbstractIOMonad &) const { return true; } + +bool AbstractIOMonad::operator==(const AbstractBase &other) const { + if (&other == this) { + return true; + } + return other.isa(); +} + } // namespace abstract } // namespace mindspore diff --git a/mindspore/core/abstract/abstract_value.h b/mindspore/core/abstract/abstract_value.h index 6e582d6075..73d55e0c46 100644 --- a/mindspore/core/abstract/abstract_value.h +++ b/mindspore/core/abstract/abstract_value.h @@ -207,6 +207,8 @@ class AbstractFunction : public AbstractBase { virtual AnfNodePtr tracking_id() const { return nullptr; } virtual void set_tracking_id(AnfNodePtr) {} virtual AnalysisContextPtr context() const { return nullptr; } + // Whether the function itself has isolate nodes; this does not include used functions or HOFs. 
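+ // Defaults to false; AbstractFuncUnion, FuncGraphAbstractClosure and PartialAbstractClosure override it (see abstract_function.h above).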
+ virtual bool HasIsolateNodesFlag() const { return false; } }; using AbstractFunctionPtrList = std::vector; @@ -615,6 +617,7 @@ class AbstractRef : public AbstractTensor { } return std::make_shared(ref_key_->Clone(), abs_tensor); } + AbstractBasePtr CloneAsTensor() const { return AbstractTensor::Clone(); } std::string ToString() const override; inline AbstractTensorPtr ref() { return shared_from_base(); } inline AbstractBasePtr ref_key() const { return ref_key_; } @@ -707,6 +710,53 @@ class AbstractSparseTensor : public AbstractUndetermined { AbstractTensorPtr values_; AbstractTuplePtr dense_shape_; }; + +class AbstractMonad : public AbstractBase { + public: + ~AbstractMonad() override = default; + MS_DECLARE_PARENT(AbstractMonad, AbstractBase) + + std::size_t hash() const override { return hash_combine({tid()}); } + TypePtr BuildType() const override { return GetTypeTrack(); } + AbstractBasePtr Broaden(uint8_t config) const override { return AbstractBase::Broaden(config); } + AbstractBasePtr Join(const AbstractBasePtr &other) override = 0; + std::string ToString() const override { + std::ostringstream buffer; + buffer << type_name() << "(" << GetValueTrack()->ToString() << ")"; + return buffer.str(); + } + + protected: + AbstractMonad(const ValuePtr &value, const TypePtr &type) : AbstractBase(value, type) {} +}; +using AbstractMonadPtr = std::shared_ptr; + +class AbstractUMonad : public AbstractMonad { + public: + explicit AbstractUMonad(const ValuePtr &value = kUMonad) : AbstractMonad(value, kUMonadType) {} + ~AbstractUMonad() override = default; + MS_DECLARE_PARENT(AbstractUMonad, AbstractMonad) + + AbstractBasePtr Clone() const override { return std::make_shared(GetValueTrack()); } + AbstractBasePtr Join(const AbstractBasePtr &other) override; + bool operator==(const AbstractUMonad &other) const; + bool operator==(const AbstractBase &other) const override; +}; +using AbstractUMonadPtr = std::shared_ptr; + +class AbstractIOMonad : public AbstractMonad { + public: + explicit AbstractIOMonad(const ValuePtr &value = kIOMonad) : AbstractMonad(value, kIOMonadType) {} + ~AbstractIOMonad() override = default; + MS_DECLARE_PARENT(AbstractIOMonad, AbstractMonad) + + AbstractBasePtr Clone() const override { return std::make_shared(GetValueTrack()); } + AbstractBasePtr Join(const AbstractBasePtr &other) override; + bool operator==(const AbstractIOMonad &other) const; + bool operator==(const AbstractBase &other) const override; +}; +using AbstractIOMonadPtr = std::shared_ptr; + } // namespace abstract } // namespace mindspore #endif // MINDSPORE_CORE_ABSTRACT_ABSTRACT_VALUE_H_ diff --git a/mindspore/core/abstract/infer_functions.h b/mindspore/core/abstract/infer_functions.h index 98e9e8c745..2ced61888c 100644 --- a/mindspore/core/abstract/infer_functions.h +++ b/mindspore/core/abstract/infer_functions.h @@ -182,6 +182,8 @@ AbstractBasePtr InferImplStateSetItem(const AnalysisEnginePtr &, const Primitive const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplUpdateState(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplDebug(const AnalysisEnginePtr &, const PrimitivePtr &primitive, @@ -203,6 +205,9 @@ AbstractBasePtr 
InferImplRowTensorGetIndices(const AnalysisEnginePtr &, const Pr const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplRowTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); +AbstractBasePtr InferImplRowTensorAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list); + AbstractBasePtr InferImplUniqueGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); diff --git a/mindspore/core/abstract/param_validator.cc b/mindspore/core/abstract/param_validator.cc index 0a88482b73..73aface634 100644 --- a/mindspore/core/abstract/param_validator.cc +++ b/mindspore/core/abstract/param_validator.cc @@ -208,5 +208,16 @@ std::string CheckAttrStringSet(const std::string &op, const ValuePtr &attr, cons } return attr_val; } + +void CheckRequiredArgsSize(const std::string &op, const mindspore::abstract::AbstractBasePtrList &args_spec_list, + size_t size_expect) { + if (args_spec_list.size() < size_expect) { + MS_LOG(EXCEPTION) << op << " requires at least " << size_expect << " input args, but got " << args_spec_list.size(); + } + for (size_t i = 0; i < size_expect; i++) { + MS_EXCEPTION_IF_NULL(args_spec_list[i]); + } +} + } // namespace abstract } // namespace mindspore diff --git a/mindspore/core/abstract/param_validator.h b/mindspore/core/abstract/param_validator.h index 838df0d45a..11328d910e 100644 --- a/mindspore/core/abstract/param_validator.h +++ b/mindspore/core/abstract/param_validator.h @@ -61,6 +61,8 @@ std::vector CheckAttrIntOrTuple(const std::string &op, const ValuePtr & std::string CheckAttrStringSet(const std::string &op, const ValuePtr &attr, const std::string &attr_name, const std::set &val_set); +void CheckRequiredArgsSize(const std::string &op, const AbstractBasePtrList &args_spec_list, size_t size_expect); + template struct ReportNameTraits {}; diff --git a/mindspore/core/abstract/prim_arrays.cc b/mindspore/core/abstract/prim_arrays.cc index 2e1e7f6f47..7244ba9d0c 100644 --- a/mindspore/core/abstract/prim_arrays.cc +++ b/mindspore/core/abstract/prim_arrays.cc @@ -396,7 +396,7 @@ AbstractBasePtr InferImplUnsortedSegmentMin(const AnalysisEnginePtr &, const Pri AbstractBasePtr InferImplScatterAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) { const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 3); + CheckRequiredArgsSize(op_name, args_spec_list, 3); auto x = CheckArg(op_name, args_spec_list, 0); MS_EXCEPTION_IF_NULL(x); MS_EXCEPTION_IF_NULL(x->shape()); @@ -410,7 +410,7 @@ AbstractBasePtr InferImplScatterUpdate(const AnalysisEnginePtr &, const Primitiv const AbstractBasePtrList &args_spec_list) { const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 3); + CheckRequiredArgsSize(op_name, args_spec_list, 3); auto x = CheckArg(op_name, args_spec_list, 0); MS_EXCEPTION_IF_NULL(x); MS_EXCEPTION_IF_NULL(x->shape()); diff --git a/mindspore/core/abstract/prim_nn.cc b/mindspore/core/abstract/prim_nn.cc index 916253ee1d..ac36dc98b8 100644 --- a/mindspore/core/abstract/prim_nn.cc +++ b/mindspore/core/abstract/prim_nn.cc @@ -406,7 +406,7 @@ AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr 
&p AbstractBasePtr InferImplConv2DBackpropInput(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) { // Inputs: three tensors(doutput, input, filters). - CheckArgsSize(primitive->name(), args_spec_list, 3); + CheckRequiredArgsSize(primitive->name(), args_spec_list, 3); return args_spec_list[1]->Broaden(); } @@ -580,7 +580,7 @@ AbstractBasePtr InferImplDropoutGenMask(const AnalysisEnginePtr &, const Primiti // Inputs: a tuple and a tensor. // Outputs: mask. const std::string op_name = primitive->name(); - CheckArgsSize(op_name, args_spec_list, 2); + CheckRequiredArgsSize(op_name, args_spec_list, 2); AbstractTuplePtr x_shape = CheckArg(op_name, args_spec_list, 0); AbstractTensorPtr keep_prob = CheckArg(op_name, args_spec_list, 1); @@ -627,7 +627,7 @@ AbstractBasePtr InferImplDropoutGenMask(const AnalysisEnginePtr &, const Primiti AbstractBasePtr InferImplSparseApplyFtrl(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) { - CheckArgsSize(primitive->name(), args_spec_list, 5); + CheckRequiredArgsSize(primitive->name(), args_spec_list, 5); AbstractBasePtrList elements; for (size_t i = 0; i < 3; ++i) { elements.push_back(args_spec_list[i]->Clone()->Broaden()); @@ -637,7 +637,7 @@ AbstractBasePtr InferImplSparseApplyFtrl(const AnalysisEnginePtr &, const Primit AbstractBasePtr InferImplSparseApplyProximalAdagrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) { - CheckArgsSize(primitive->name(), args_spec_list, 7); + CheckRequiredArgsSize(primitive->name(), args_spec_list, 7); AbstractBasePtrList elements; for (size_t i = 0; i < 2; ++i) { elements.push_back(args_spec_list[i]->Clone()->Broaden()); @@ -647,7 +647,7 @@ AbstractBasePtr InferImplSparseApplyProximalAdagrad(const AnalysisEnginePtr &, c AbstractBasePtr InferImplSGD(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) { - CheckArgsSize(primitive->name(), args_spec_list, 6); + CheckRequiredArgsSize(primitive->name(), args_spec_list, 6); AbstractBasePtrList elements; elements.push_back(args_spec_list[0]->Clone()->Broaden()); return std::make_shared(elements); diff --git a/mindspore/core/abstract/prim_others.cc b/mindspore/core/abstract/prim_others.cc index 833b7df205..4811e17208 100644 --- a/mindspore/core/abstract/prim_others.cc +++ b/mindspore/core/abstract/prim_others.cc @@ -165,7 +165,10 @@ AbstractBasePtr InferImplStateSetItem(const AnalysisEnginePtr &, const Primitive AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) { if (args_spec_list.empty()) { - MS_LOG(EXCEPTION) << primitive->name() << " input args size should be at lest 1, but got 0"; + MS_LOG(EXCEPTION) << primitive->name() << " input args size should be at least 1, but got 0"; + } + if (primitive->GetAttr(ATTR_NO_BROADEN) != nullptr) { + return args_spec_list[0]; } auto depends = args_spec_list[0]->Broaden(); // For scalar, need to set value to kAnyValue, because broaden scalar will not change the value. 
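The comment above is the heart of the Depend (and new UpdateState) inference rules: Broaden keeps an abstract value's type and shape but replaces the concrete value with kAnyValue, so later passes cannot constant-fold through the dependency chain. A minimal self-contained C++ sketch of that idea, using hypothetical names rather than MindSpore's real AbstractBase API:

#include <optional>
#include <string>

// Sketch: "broadening" keeps the static type but forgets the concrete value.
struct AbstractScalar {
  std::string type;          // e.g. "Int32"
  std::optional<int> value;  // concrete value if known; empty plays the role of kAnyValue
};

AbstractScalar Broaden(const AbstractScalar &a) {
  return {a.type, std::nullopt};  // same type, value erased
}

int main() {
  AbstractScalar s{"Int32", 3};
  AbstractScalar b = Broaden(s);
  return b.value.has_value() ? 1 : 0;  // returns 0: nothing left to constant-fold
}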
@@ -175,6 +178,14 @@ AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &p return depends; } +AbstractBasePtr InferImplUpdateState(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + if (args_spec_list.empty()) { + MS_LOG(EXCEPTION) << primitive->name() << " input args size should be at least 1, but got 0"; + } + return args_spec_list[0]->Broaden(); +} + AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) { // args: Two objects of a subclass of AbstractBase @@ -282,6 +293,18 @@ AbstractBasePtr InferImplRowTensorGetDenseShape(const AnalysisEnginePtr &, const return row_tensor->dense_shape(); } +AbstractBasePtr InferImplRowTensorAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive, + const AbstractBasePtrList &args_spec_list) { + // Inputs: row tensor and tensor. + const std::string op_name = primitive->name(); + CheckArgsSize(op_name, args_spec_list, 2); + auto row_tensor = CheckArg(op_name, args_spec_list, 0); + auto tensor = CheckArg(op_name, args_spec_list, 1); + MS_EXCEPTION_IF_NULL(row_tensor->dense_shape()); + MS_EXCEPTION_IF_NULL(tensor->shape()); + return args_spec_list[0]; +} + AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list) { // Inputs: two tensors and a tuple. diff --git a/mindspore/core/abstract/primitive_infer_map.cc b/mindspore/core/abstract/primitive_infer_map.cc index 642fcb9257..2b348ba0e5 100644 --- a/mindspore/core/abstract/primitive_infer_map.cc +++ b/mindspore/core/abstract/primitive_infer_map.cc @@ -163,6 +163,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { {prim::kPrimGetRefValue, {InferImplGetRefValue, true}}, {prim::kPrimStateSetItem, {InferImplStateSetItem, true}}, {prim::kPrimDepend, {InferImplDepend, true}}, + {prim::kPrimUpdateState, {InferImplUpdateState, true}}, {prim::kPrimControlDepend, {InferImplControlDepend, true}}, // Debug {prim::kPrimDebug, {InferImplDebug, true}}, @@ -178,6 +179,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() { {prim::kPrimRowTensorGetValues, {InferImplRowTensorGetValues, true}}, {prim::kPrimRowTensorGetIndices, {InferImplRowTensorGetIndices, true}}, {prim::kPrimRowTensorGetDenseShape, {InferImplRowTensorGetDenseShape, true}}, + {prim::kPrimRowTensorAdd, {InferImplRowTensorAdd, false}}, // Comm Ops {prim::kPrimAllReduce, {InferImplAllReduce, true}}, {prim::kPrimBroadcast, {InferImplBroadcast, true}}, diff --git a/mindspore/core/base/base.h b/mindspore/core/base/base.h index e43b042cfa..1bc579d620 100644 --- a/mindspore/core/base/base.h +++ b/mindspore/core/base/base.h @@ -102,7 +102,7 @@ inline T *cast(U *source) { template < typename T, typename U, typename std::enable_if::value && std::is_base_of::value, T>::type * = nullptr> -inline std::shared_ptr dyn_cast(const std::shared_ptr r) { +inline std::shared_ptr dyn_cast(const std::shared_ptr &r) { if (r != nullptr && r->template isa()) { return std::static_pointer_cast(r); } else { diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h index 19d7d2464a..d5190f147b 100644 --- a/mindspore/core/base/core_ops.h +++ b/mindspore/core/base/core_ops.h @@ -20,11 +20,18 @@ #include #include #include +#include #include "ir/anf.h" #include "ir/primitive.h" +#include "utils/flags.h" namespace mindspore { namespace prim { +inline const ValuePtr kValueOne = std::make_shared(1); +inline const 
std::unordered_map kSideEffectPropagate = { + {mindspore::GRAPH_FLAG_SIDE_EFFECT_PROPAGATE, kValueOne}, +}; + constexpr auto kGather = "Gather"; // Arithmetic constexpr auto kScalarAdd = "ScalarAdd"; @@ -312,6 +319,7 @@ inline const PrimitivePtr kPrimMakeRowTensor = std::make_shared("Make inline const PrimitivePtr kPrimRowTensorGetValues = std::make_shared("RowTensorGetValues"); inline const PrimitivePtr kPrimRowTensorGetIndices = std::make_shared("RowTensorGetIndices"); inline const PrimitivePtr kPrimRowTensorGetDenseShape = std::make_shared("RowTensorGetDenseShape"); +inline const PrimitivePtr kPrimRowTensorAdd = std::make_shared("RowTensorAdd"); // SparseTensor inline const PrimitivePtr kPrimMakeSparseTensor = std::make_shared("MakeSparseTensor"); @@ -417,10 +425,12 @@ inline const PrimitivePtr kPrimGpuConvertToDynamicShape = std::make_shared("ErrorOnDynamicShapeInput"); // Other miscellaneous -inline const PrimitivePtr kPrimDepend = std::make_shared("Depend"); +inline const PrimitivePtr kPrimDepend = std::make_shared("Depend", kSideEffectPropagate); inline const PrimitivePtr kPrimReformat = std::make_shared("Reformat"); -inline const PrimitivePtr kPrimPartial = std::make_shared("Partial"); -inline const PrimitivePtr kPrimIdentity = std::make_shared("identity"); +inline const PrimitivePtr kPrimLoad = std::make_shared("Load"); +inline const PrimitivePtr kPrimUpdateState = std::make_shared("UpdateState"); +inline const PrimitivePtr kPrimPartial = std::make_shared("Partial", kSideEffectPropagate); +inline const PrimitivePtr kPrimIdentity = std::make_shared("identity", kSideEffectPropagate); inline const PrimitivePtr kPrimHookBackward = std::make_shared("HookBackward"); inline const PrimitivePtr kPrimPrintShapeType = std::make_shared("PrintShapeType"); inline const PrimitivePtr kPrimSameTypeShape = std::make_shared("SameTypeShape"); @@ -438,6 +448,7 @@ inline const PrimitivePtr kPrimCustomPredict = std::make_shared("Cust inline const PrimitivePtr kPrimPriorBox = std::make_shared("PriorBox"); inline const PrimitivePtr kPrimQuantDTypeCast = std::make_shared("QuantDTypeCast"); inline const PrimitivePtr kPrimWhile = std::make_shared("While"); +inline const PrimitivePtr kPrimPull = std::make_shared("Pull"); // Structures inline const PrimitivePtr kPrimMakeList = std::make_shared("make_list"); @@ -462,7 +473,7 @@ inline const PrimitivePtr kPrimGetRefValue = std::make_shared("get_re // Other primitive not used by backend but used in core; inline const PrimitivePtr kPrimStateSetItem = std::make_shared("state_setitem"); -inline const PrimitivePtr kPrimJ = std::make_shared("J"); +inline const PrimitivePtr kPrimJ = std::make_shared("J", kSideEffectPropagate); // Used to build graph which have keyword arguments inline const PrimitivePtr kPrimExtractKeywordArg = std::make_shared("extract_keyword_arg"); diff --git a/mindspore/core/base/effect_info.h b/mindspore/core/base/effect_info.h new file mode 100644 index 0000000000..32ffc37dd8 --- /dev/null +++ b/mindspore/core/base/effect_info.h @@ -0,0 +1,61 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_EFFECT_INFO_H_ +#define MINDSPORE_CORE_EFFECT_INFO_H_ + +namespace mindspore { + +struct EffectInfo { + enum State : unsigned char { + kUnknown = 0, + kDetecting = 1, + kDetected = 2, + }; + State state = kUnknown; // effect info state. + bool memory = false; // memory side effects, e.g., access global variable. + bool io = false; // IO side effects, e.g., print message. + bool load = false; // load value from global variable, e.g. add(self.para, x). + + void Merge(const EffectInfo &info) { + if (info.state != EffectInfo::kDetected) { + state = EffectInfo::kDetecting; + } + memory = memory || info.memory; + io = io || info.io; + load = load || info.load; + } +}; + +// EffectInfoHolder as base class for effect info holders, such as CNode, FuncGraph, etc. +class EffectInfoHolder { + public: + // Gets effect info. + const EffectInfo &GetEffectInfo() const { return effect_info_; } + + // Set effect info. + void SetEffectInfo(const EffectInfo &info) { effect_info_ = info; } + + // Unset effect info. + void UnsetEffectInfo() { effect_info_ = {EffectInfo::kUnknown, false, false}; } + + protected: + EffectInfo effect_info_; +}; + +} // namespace mindspore + +#endif // MINDSPORE_CORE_EFFECT_INFO_H_ diff --git a/mindspore/core/ir/anf.cc b/mindspore/core/ir/anf.cc index 4bc1cfc058..33e50fa347 100644 --- a/mindspore/core/ir/anf.cc +++ b/mindspore/core/ir/anf.cc @@ -21,6 +21,7 @@ #include #include #include +#include #include #include "base/core_ops.h" @@ -31,7 +32,11 @@ namespace mindspore { // namespace to support intermediate representation definition CNode::CNode(const std::vector &inputs, const FuncGraphPtr &func_graph) - : AnfNode(func_graph), inputs_(inputs), stop_gradient_(false), output_value_(std::make_pair(nullptr, "")) {} + : AnfNode(func_graph), + inputs_(inputs), + stop_gradient_(false), + output_value_(std::make_pair(nullptr, "")), + input_tensor_num_(-1) {} // Check if CNode is an apply with the specific Primitive. 
bool CNode::IsApply(const PrimitivePtr &value) const { @@ -49,7 +54,20 @@ bool CNode::IsApply(const PrimitivePtr &value) const { return false; } -void CNode::set_input(size_t i, const AnfNodePtr &new_input) { inputs_[i] = new_input; } +void CNode::add_input(const AnfNodePtr &input) { + inputs_.push_back(input); + input_tensor_num_ = -1; +} + +void CNode::set_input(size_t i, const AnfNodePtr &new_input) { + inputs_[i] = new_input; + input_tensor_num_ = -1; +} + +void CNode::set_inputs(const std::vector &inputs) { + inputs_ = inputs; + input_tensor_num_ = -1; +} std::string CNode::DebugString(int recursive_level) const { std::ostringstream buffer; @@ -128,8 +146,7 @@ std::string ValueNode::fullname_with_scope() { } bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &value) { - MS_EXCEPTION_IF_NULL(node); - auto cnode = node->cast(); + auto cnode = dyn_cast(node); if (cnode == nullptr) { return false; } @@ -171,6 +188,14 @@ std::string GetCNodeFuncName(const CNodePtr cnode) { return ""; } +FuncGraphPtr GetCNodeFuncGraph(const AnfNodePtr &node) { + auto cnode = dyn_cast(node); + if (cnode != nullptr && cnode->size() > 0) { + return GetValueNode(cnode->input(0)); + } + return nullptr; +} + bool IsPrimitive(const AnfNodePtr &node, const PrimitivePtr &value) { if (IsValueNode(node)) { PrimitivePtr fn_value = GetValueNode(node); @@ -182,6 +207,99 @@ bool IsPrimitive(const AnfNodePtr &node, const PrimitivePtr &value) { return false; } +bool IsPrimitiveEquals(const PrimitivePtr &prim1, const PrimitivePtr &prim2) { + if (prim1 == nullptr || prim2 == nullptr) { + return false; + } + return (prim1 == prim2) || (prim1->Hash() == prim2->Hash() && prim1->name() == prim2->name()); +} + +size_t GetAbstractMonadNum(const AbstractBasePtrList &args) { + size_t num = 0; + for (auto &arg : args) { + if (arg->isa()) { + ++num; + } + } + return num; +} + +template +bool HasAbstract(const AnfNodePtr &node) { + if (node == nullptr) { + return false; + } + const auto &abs = node->abstract(); + return (abs != nullptr && abs->isa()); +} + +bool HasAbstractMonad(const AnfNodePtr &node) { return HasAbstract(node); } + +bool HasAbstractUMonad(const AnfNodePtr &node) { return HasAbstract(node); } + +bool HasAbstractIOMonad(const AnfNodePtr &node) { return HasAbstract(node); } + +bool GetPrimitiveFlag(const PrimitivePtr &prim, const std::string &attr) { + if (prim != nullptr) { + auto flag = prim->GetAttr(attr); + if (flag && flag->isa()) { + return GetValue(flag); + } + } + return false; +} + +EffectInfo GetPrimEffectInfo(const PrimitivePtr &prim) { + bool mem = GetPrimitiveFlag(prim, GRAPH_FLAG_SIDE_EFFECT_MEM); + bool io = GetPrimitiveFlag(prim, GRAPH_FLAG_SIDE_EFFECT_IO); + return {EffectInfo::kDetected, mem, io, false}; +} + +MonadState GetMonadState(const AnfNodePtr &node, const AnfNodePtr &skip_input) { + if (node == nullptr) { + return {}; + } + MonadState state; + size_t seen = NewSeenGeneration(); + std::queue que; + que.push(node); + while (!que.empty()) { + auto n = que.front(); + que.pop(); + + // check whether this node has been matched or should be skipped. + if (n == nullptr || n->seen_ == seen || n == skip_input) { + continue; + } + n->seen_ = seen; + + // check whether this node has monad abstract. 
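+ // Record the first UMonad and the first IOMonad found; every other node is searched further through its inputs.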
+ if (state.u == nullptr && HasAbstractUMonad(n)) { + state.u = n; + } else if (state.io == nullptr && HasAbstractIOMonad(n)) { + state.io = n; + } else { + auto cnode = dyn_cast(n); + if (cnode != nullptr) { + for (auto it = cnode->inputs().rbegin(); it != cnode->inputs().rend(); ++it) { + que.push(*it); + } + } + continue; + } + + if (state.u != nullptr && state.io != nullptr) { + return state; + } + } + return state; +} + +bool IsStateEquivalent(const MonadState &state1, const MonadState &state2) { + return (state1.u == nullptr || state2.u == nullptr || state1.u == state2.u) && + (state1.io == nullptr || state2.io == nullptr || state1.io == state2.io); +} + size_t NewSeenGeneration() { static size_t seen_generation = 0; return ++seen_generation; @@ -246,6 +364,11 @@ std::string GetCNodeTarget(const AnfNodePtr &node) { } auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); + const std::string primitive_target = "primitive_target"; + auto ud_target = cnode->user_data(primitive_target); + if (ud_target != nullptr) { + return *ud_target.get(); + } auto attr_input = cnode->input(0); if (attr_input == nullptr) { return default_target; @@ -262,14 +385,14 @@ std::string GetCNodeTarget(const AnfNodePtr &node) { return default_target; } auto primitive = value->cast(); - auto att_target = primitive->GetAttr("primitive_target"); + auto att_target = primitive->GetAttr(primitive_target); if (att_target != nullptr) { if (IsPrimitive(attr_input, prim::kPrimImageSummary) || IsPrimitive(attr_input, prim::kPrimScalarSummary) || IsPrimitive(attr_input, prim::kPrimTensorSummary) || IsPrimitive(attr_input, prim::kPrimHistogramSummary) || IsPrimitive(attr_input, prim::kPrimStateSetItem) || IsPrimitive(attr_input, prim::kPrimDepend) || IsPrimitive(attr_input, prim::kPrimControlDepend) || IsPrimitive(attr_input, prim::kPrimReturn) || IsPrimitive(attr_input, prim::kPrimPartial)) { - primitive->EraseAttr("primitive_target"); + primitive->EraseAttr(primitive_target); return default_target; } if (!att_target->isa()) { diff --git a/mindspore/core/ir/anf.h b/mindspore/core/ir/anf.h index d629ca82a0..27b2399077 100644 --- a/mindspore/core/ir/anf.h +++ b/mindspore/core/ir/anf.h @@ -29,6 +29,7 @@ #include "base/base.h" #include "base/user_data.h" +#include "base/effect_info.h" #include "ir/kernel_info_dev.h" #include "ir/scope.h" #include "utils/info.h" @@ -122,7 +123,7 @@ class AnfNode : public Base { const KernelInfoDevicePtr &kernel_info_ptr() { return kernel_info_; } void set_kernel_info(const KernelInfoDevicePtr &kernel_info) { kernel_info_ = kernel_info; } - AbstractBasePtr abstract() const { return abstract_; } + const AbstractBasePtr &abstract() const { return abstract_; } void set_abstract(const AbstractBasePtr &abs) { abstract_ = abs; } AbstractBasePtr intermediate_abstract() { return intermediate_abstract_; } @@ -189,6 +190,8 @@ class AnfNode : public Base { return user_data_.has(T::key); } + void CloneUserData(const AnfNodePtr &node) { user_data_ = node->user_data_; } + int64_t stage() { return stage_; } void set_stage(const int &stage) { stage_ = stage; } @@ -225,11 +228,15 @@ class AnfNode : public Base { // stop_gradient_: a flag used to stop gradient. // Using stop_gradient() to get this flag, mainly used in ad. // Using set_stop_gradient() to set this flag. 
-class CNode : public AnfNode { +class CNode : public AnfNode, public EffectInfoHolder { public: CNode(const std::vector &inputs, const FuncGraphPtr &func_graph); CNode(const std::vector &inputs, const VarPtr &func_graph_as_var) - : AnfNode(nullptr), inputs_(inputs), func_graph_as_var_(func_graph_as_var), stop_gradient_(false) {} + : AnfNode(nullptr), + inputs_(inputs), + func_graph_as_var_(func_graph_as_var), + stop_gradient_(false), + input_tensor_num_(-1) {} ~CNode() override = default; MS_DECLARE_PARENT(CNode, AnfNode); @@ -241,9 +248,9 @@ class CNode : public AnfNode { const size_t size() const { return inputs_.size(); } const AnfNodePtr input(size_t i) const { return inputs_[i]; } const std::vector &inputs() const { return inputs_; } - void add_input(const AnfNodePtr &input) { inputs_.push_back(input); } + void add_input(const AnfNodePtr &input); void set_input(size_t i, const AnfNodePtr &input); - void set_inputs(const std::vector &inputs) { inputs_ = inputs; } + void set_inputs(const std::vector &inputs); void add_input_value(const ValuePtr &input_value, const std::string &id) { inputs_value_.push_back(std::make_pair(input_value, id)); @@ -282,17 +289,27 @@ class CNode : public AnfNode { return iter == attrs_.cend() ? nullptr : iter->second; } bool HasAttr(const std::string &name) const { return attrs_.find(name) != attrs_.cend(); } + ssize_t input_tensor_num() const { return input_tensor_num_; } + void set_input_tensor_num(ssize_t input_tensor_num) { input_tensor_num_ = input_tensor_num; } + + // Whether the effect has been handled. + bool IsEffectHandled() const { return effect_handled_; } + + // Set whether the effect has been handled. + void SetEffectHandled(bool handled) { effect_handled_ = handled; } private: std::vector inputs_; VarPtr func_graph_as_var_; bool stop_gradient_; bool in_forward_flag_ = false; + bool effect_handled_ = false; // inputs_value_ store cnode input value and id in pynative mode // output_value_ store cnode value and id in pynative mode std::vector> inputs_value_; std::pair output_value_; std::unordered_map attrs_; + ssize_t input_tensor_num_ = -1; }; // ANode represents the atomic node. It's derived Parameter and ValueNode. @@ -344,10 +361,10 @@ class Parameter : public ANode { return shared_from_this() == other.shared_from_this(); } - void set_used_by_real_kernel() { is_real_kernel_used_ = false; } + void set_used_by_real_kernel(bool used) { is_real_kernel_used_ = used; } bool is_used_by_real_kernel() { return is_real_kernel_used_; } - void set_used_by_dynamic_kernel() { is_used_by_dynamic_kernel_ = true; } + void set_used_by_dynamic_kernel(bool used) { is_used_by_dynamic_kernel_ = used; } bool is_used_by_dynamic_kernel() { return is_used_by_dynamic_kernel_; } private: @@ -469,6 +486,9 @@ static S GetValue(const ValuePtr &value) { std::string GetCNodeFuncName(CNodePtr cnode); +// used to get FuncGraphPtr from a cnode's first input +FuncGraphPtr GetCNodeFuncGraph(const AnfNodePtr &node); + // used to check whether an AnfNode is a cnode with a kind of Primitive as first input bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &value = nullptr); @@ -478,6 +498,38 @@ PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node); // used to check whether an AnfNode is a valuenode having some Primitive value bool IsPrimitive(const AnfNodePtr &node, const PrimitivePtr &value);
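The declarations that follow are driven purely by primitive attributes: GetPrimitiveFlag reads one bool attribute, and GetPrimEffectInfo folds the "side_effect_mem"/"side_effect_io" flags into a fully detected EffectInfo. A hedged sketch of that flow (not part of the patch; it assumes Primitive::set_attr and MakeValue behave as elsewhere in this codebase):

  // Sketch only: a primitive flagged with "side_effect_io" is detected as an IO effect.
  auto prim = std::make_shared<Primitive>("Print");
  prim->set_attr(GRAPH_FLAG_SIDE_EFFECT_IO, MakeValue(true));  // flag constant from utils/flags.h below
  EffectInfo info = GetPrimEffectInfo(prim);
  // Now: info.state == EffectInfo::kDetected, info.io == true,
  //      info.memory == false, info.load == false.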
+// Check whether two primitives are the same. +bool IsPrimitiveEquals(const PrimitivePtr &prim1, const PrimitivePtr &prim2); + +// Get the number of AbstractMonad in the argument list. +size_t GetAbstractMonadNum(const AbstractBasePtrList &args); + +// Check whether the given node has monad abstract. +bool HasAbstractMonad(const AnfNodePtr &node); + +// Check whether the given node has U monad abstract. +bool HasAbstractUMonad(const AnfNodePtr &node); + +// Check whether the given node has IO monad abstract. +bool HasAbstractIOMonad(const AnfNodePtr &node); + +// Gets primitive attribute value as a bool flag. +bool GetPrimitiveFlag(const PrimitivePtr &prim, const std::string &attr); + +// Gets effect info from a primitive by its attributes. +EffectInfo GetPrimEffectInfo(const PrimitivePtr &prim); + +struct MonadState { + AnfNodePtr u{nullptr}; + AnfNodePtr io{nullptr}; +}; + +// Get the memory/IO monad state from a node. +MonadState GetMonadState(const AnfNodePtr &node, const AnfNodePtr &skip_input = nullptr); + +// Check whether two monad states are equivalent. +bool IsStateEquivalent(const MonadState &state1, const MonadState &state2); + // used to check whether a ValueNode has some kind of value template static bool IsValueNode(const AnfNodePtr &node) { diff --git a/mindspore/core/ir/dtype.h b/mindspore/core/ir/dtype.h index 3ef58b4695..655de0d715 100644 --- a/mindspore/core/ir/dtype.h +++ b/mindspore/core/ir/dtype.h @@ -37,6 +37,7 @@ #include "ir/dtype/empty.h" #include "ir/dtype/tensor_type.h" #include "ir/dtype/ref.h" +#include "ir/dtype/monad_type.h" /* namespace to support intermediate representation definition */ namespace mindspore { diff --git a/mindspore/ccsrc/frontend/optimizer/control_depend.h b/mindspore/core/ir/dtype/monad_type.cc similarity index 58% rename from mindspore/ccsrc/frontend/optimizer/control_depend.h rename to mindspore/core/ir/dtype/monad_type.cc index 60a00e5b51..fda3ab9574 100644 --- a/mindspore/ccsrc/frontend/optimizer/control_depend.h +++ b/mindspore/core/ir/dtype/monad_type.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,15 +14,9 @@ * limitations under the License. */ -#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_CONTROL_DEPEND_H_ -#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_CONTROL_DEPEND_H_ - -#include "ir/anf.h" +#include "ir/dtype/monad_type.h" namespace mindspore { -namespace opt { -// Automatically adding control depend based on effect order and side effect analysis. -void AddControlDepend(const FuncGraphPtr &graph); -} // namespace opt +const TypePtr kUMonadType = std::make_shared(); +const TypePtr kIOMonadType = std::make_shared(); } // namespace mindspore -#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_CONTROL_DEPEND_H_
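The new header introduced below declares the UMonadType and IOMonadType classes behind these two singletons. A hedged sketch of how they are expected to behave, based only on the declarations in this patch:

  #include <cassert>
  #include "ir/dtype/monad_type.h"
  // kUMonadType and kIOMonadType are the shared singletons defined in monad_type.cc above.
  assert(kUMonadType->ToString() == "UMonad");
  assert(kIOMonadType->generic_type_id() == kObjectTypeIOMonad);
  auto copy = kIOMonadType->DeepCopy();  // a fresh IOMonadType, equal in kind but not in identity
  assert(copy->ToString() == "IOMonad");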
diff --git a/mindspore/core/ir/dtype/monad_type.h b/mindspore/core/ir/dtype/monad_type.h new file mode 100644 index 0000000000..c76b583b27 --- /dev/null +++ b/mindspore/core/ir/dtype/monad_type.h @@ -0,0 +1,70 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_IR_DTYPE_MONAD_H_ +#define MINDSPORE_CORE_IR_DTYPE_MONAD_H_ + +#include <memory> +#include <string> + +#include "base/base.h" +#include "ir/dtype/type.h" + +namespace mindspore { +class MonadType : public Object { + public: + ~MonadType() override = default; + MS_DECLARE_PARENT(MonadType, Object) + + TypeId generic_type_id() const override { return kObjectTypeMonad; } + TypePtr DeepCopy() const override = 0; + + protected: + explicit MonadType(const TypeId type_id) : Object(type_id) {} +}; +using MonadTypePtr = std::shared_ptr; + +// UniversalMonad type +class UMonadType : public MonadType { + public: + UMonadType() : MonadType(kObjectTypeUMonad) {} + ~UMonadType() override {} + MS_DECLARE_PARENT(UMonadType, MonadType) + + TypeId generic_type_id() const override { return kObjectTypeUMonad; } + TypePtr DeepCopy() const override { return std::make_shared(); } + std::string ToString() const override { return "UMonad"; } +}; +using UMonadTypePtr = std::shared_ptr; + +// IOMonad type +class IOMonadType : public MonadType { + public: + IOMonadType() : MonadType(kObjectTypeIOMonad) {} + ~IOMonadType() override {} + MS_DECLARE_PARENT(IOMonadType, MonadType) + + TypeId generic_type_id() const override { return kObjectTypeIOMonad; } + TypePtr DeepCopy() const override { return std::make_shared(); } + std::string ToString() const override { return "IOMonad"; } +}; +using IOMonadTypePtr = std::shared_ptr; + +extern const TypePtr kIOMonadType; +extern const TypePtr kUMonadType; +} // namespace mindspore + +#endif // MINDSPORE_CORE_IR_DTYPE_MONAD_H_ diff --git a/mindspore/core/ir/dtype/type.cc b/mindspore/core/ir/dtype/type.cc index ab8d4941f1..e0cf63225a 100644 --- a/mindspore/core/ir/dtype/type.cc +++ b/mindspore/core/ir/dtype/type.cc @@ -137,6 +137,12 @@ const char *ObjectIdLabel(const TypeId &v) { return "kObjectTypeRefKey"; case kObjectTypeRef: return "kObjectTypeRef"; + case kObjectTypeMonad: + return "kObjectTypeMonad"; + case kObjectTypeUMonad: + return "kObjectTypeUMonad"; + case kObjectTypeIOMonad: + return "kObjectTypeIOMonad"; default: return "[Unknown Type Id]"; } @@ -185,6 +191,9 @@ const char *TypeIdLabel(const TypeId &v) { } else { if (v < kObjectTypeEnd) { return ObjectIdLabel(v); + } else if (v > kMonadTypeBegin && v < kMonadTypeEnd) { + // Monad types are object types. + return ObjectIdLabel(v); } else { return NumberIdLabel(v); } diff --git a/mindspore/core/ir/dtype/type_id.h b/mindspore/core/ir/dtype/type_id.h index 7933346157..46209b8ba4 100644 --- a/mindspore/core/ir/dtype/type_id.h +++ b/mindspore/core/ir/dtype/type_id.h @@ -79,7 +79,17 @@ enum TypeId : int { kNumberTypeFloat32, kNumberTypeFloat64, kNumberTypeComplex64, - kNumberTypeEnd + kNumberTypeEnd, + // + // Monad Types + // + // Monad types are placed at the end of the enum + // to stay compatible with the types of existing models on the lite side.
+ kMonadTypeBegin = kNumberTypeEnd, + kObjectTypeMonad, + kObjectTypeUMonad, + kObjectTypeIOMonad, + kMonadTypeEnd }; } // namespace mindspore #endif // MINDSPORE_CORE_IR_DTYPE_TYPE_ID_H_ diff --git a/mindspore/core/ir/dtype_extends.cc b/mindspore/core/ir/dtype_extends.cc index d08d92a3cd..66ba092503 100644 --- a/mindspore/core/ir/dtype_extends.cc +++ b/mindspore/core/ir/dtype_extends.cc @@ -123,7 +123,12 @@ TypePtr TypeIdToType(TypeId id) { return kKeyword; case kObjectTypeTensorType: return kTensorType; + case kObjectTypeUMonad: + return kUMonadType; + case kObjectTypeIOMonad: + return kIOMonadType; case kTypeUnknown: + case kMetaTypeProblem: return kTypeNone; default: MS_LOG(EXCEPTION) << "Not support the type: " << id; @@ -390,6 +395,10 @@ TypePtr StringToType(const std::string &type_name) { type = FunctionStrToType(type_name); } else if (type_name == "mstype") { type = std::make_shared(); + } else if (type_name == "UMonad") { + type = kUMonadType; + } else if (type_name == "IOMonad") { + type = kIOMonadType; } else { // - unsupported to convert // Class diff --git a/mindspore/core/ir/func_graph.cc b/mindspore/core/ir/func_graph.cc index 8387d5b6a8..d77553aef8 100644 --- a/mindspore/core/ir/func_graph.cc +++ b/mindspore/core/ir/func_graph.cc @@ -119,18 +119,41 @@ ValuePtr FuncGraph::get_attr(const std::string &key) { } CNodePtr FuncGraph::NewCNode(const std::vector &inputs) { - CNodePtr cnode = std::make_shared(inputs, shared_from_base()); - if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { - order_.push_back(cnode); - MS_LOG(INFO) << "Graph: " << ToString() << ", push back " << cnode->DebugString() << " in order."; - } + return std::make_shared(inputs, shared_from_base()); +} + +CNodePtr FuncGraph::NewCNodeInOrder(const std::vector &inputs) { + CNodePtr cnode = NewCNode(inputs); + order_.push_back(cnode); + return cnode; +} + +CNodePtr FuncGraph::NewCNodeInFront(const std::vector &inputs) { + CNodePtr cnode = NewCNode(inputs); + order_.push_front(cnode); return cnode; } -CNodePtr FuncGraph::NewCNodeWithScope(const std::vector &inputs, const ScopePtr &scope) { - CNodePtr app = NewCNode(inputs); - app->set_scope(scope); - return app; +CNodePtr FuncGraph::NewCNodeBefore(const AnfNodePtr &position, const std::vector &inputs) { + CNodePtr cnode = NewCNode(inputs); + auto iter = std::find(order_.begin(), order_.end(), position); + order_.insert(iter, cnode); + return cnode; +} + +CNodePtr FuncGraph::NewCNodeAfter(const AnfNodePtr &position, const std::vector &inputs) { + CNodePtr cnode = NewCNode(inputs); + if (!position->isa()) { + order_.push_front(cnode); + return cnode; + } + auto iter = std::find(order_.begin(), order_.end(), position); + if (iter == order_.end()) { + order_.push_front(cnode); + return cnode; + } + order_.insert(std::next(iter), cnode); + return cnode; } void FuncGraph::DumpCNodeList() { @@ -557,90 +580,129 @@ AnfNodePtr FuncGraph::GetParameterByName(const std::string &name) { return nullptr; } -void FuncGraph::add_parameter_obj_node(const AnfNodePtr &p) { paramter_obj_nodes_.push_back(p); } - std::list FuncGraph::GetOrderedCnodes() { - if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { - MS_LOG(DEBUG) << "Return ordered cnodes."; - return order_; - } else { - auto this_ptr = shared_from_base(); - auto BelongSameGraph = std::bind(IncludeBelongGraph, this_ptr, std::placeholders::_1); - auto SuccDepends = std::bind(SuccIncludeFV, this_ptr, std::placeholders::_1); - - std::list cnodes; - auto nodes = TopoSort(get_return(), SuccDepends, BelongSameGraph); - for (const auto &node : nodes) { - auto 
cnode = dyn_cast(node); - if (cnode) { - cnodes.push_back(cnode); - } + auto this_ptr = shared_from_base(); + auto BelongSameGraph = std::bind(IncludeBelongGraph, this_ptr, std::placeholders::_1); + auto SuccDepends = std::bind(SuccIncludeFV, this_ptr, std::placeholders::_1); + + std::list cnodes; + auto nodes = TopoSort(get_return(), SuccDepends, BelongSameGraph); + for (const auto &node : nodes) { + auto cnode = dyn_cast(node); + if (cnode) { + cnodes.push_back(cnode); } - return cnodes; } + return cnodes; } void FuncGraph::EraseUnusedNodeInOrder() { - if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { - auto mng = manager_.lock(); - if (mng) { - auto &all_nodes = nodes(); - // Erase unused cnode. - for (auto it = order_.begin(); it != order_.end();) { - if (all_nodes.count(*it)) { - (void)it++; - } else { - MS_LOG(DEBUG) << "Remove node " << (*it)->ToString() << " in graph " << ToString() << " order."; - it = order_.erase(it); - } + auto mng = manager_.lock(); + if (mng) { + auto &all_nodes = nodes(); + // Erase unused cnode. + for (auto it = order_.begin(); it != order_.end();) { + if (!all_nodes.contains(*it)) { + MS_LOG(DEBUG) << "Remove node " << (*it)->ToString() << " in graph " << ToString() << " order."; + it = order_.erase(it); + continue; } + (void)it++; } } } -void FuncGraph::EraseUnusedNodeInOrder(const AnfNodePtr &n) { - if (has_flag(GRAPH_FLAG_HAS_EFFECT) && n && n->isa()) { - order_.remove(n->cast()); - MS_LOG(DEBUG) << "Remove the node" << n->DebugString() << " from order list."; +void FuncGraph::EraseUnusedNodeInOrder(const AnfNodePtr &node) { + if (node) { + auto cnode = node->cast(); + if (cnode) { + order_.remove(cnode); + MS_LOG(DEBUG) << "Remove the node" << node->DebugString() << " from order list."; + } } } -void FuncGraph::CheckOrder() { - if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { - MS_LOG(DEBUG) << "Check graph " << ToString(); - for (auto it = order_.begin(); it != order_.end(); (void)it++) { - for (const auto &input_node : (*it)->inputs()) { - if (input_node && input_node->isa() && input_node->func_graph() == shared_from_base()) { - // Need to reorder the wrong order node. - auto found = std::find(order_.begin(), it, input_node); - if (found == it) { - DumpCNodeList(); - MS_LOG(EXCEPTION) << "The cnode " << (*it)->DebugString() << " order in " << ToString() - << " doesn't obey the input dependency, " - << "as input " << input_node->DebugString() << " is not ahead of itself."; - } - } - } - } - auto mng = manager_.lock(); - if (mng != nullptr) { - const auto &all_nodes = nodes(); - if (all_nodes.size() != (order_.size() + parameters_.size())) { - DumpCNodeList(); - MS_LOG(EXCEPTION) << "CNode order size " << order_.size() << " is not equal to managed node size " - << all_nodes.size() - parameters_.size() << "."; - } - } - MS_LOG(DEBUG) << "Check order okay."; +// Maintain cnode order list when a cnode is replaced by a new one. +void FuncGraph::ReplaceInOrder(const AnfNodePtr &old_node, const AnfNodePtr &new_node) { + MS_EXCEPTION_IF_NULL(old_node); + MS_EXCEPTION_IF_NULL(new_node); + if (order_.empty()) { + // Skip if order list is empty. + return; + } + auto old_cnode = old_node->cast(); + if (old_cnode == nullptr) { + // Skip if old node is not cnode, since order list contains cnode only. + return; + } + // Search old node in order list. + auto iter = std::find(order_.begin(), order_.end(), old_cnode); + if (iter == order_.end()) { + // Skip if old node not found in order list. 
+ return; + } + auto new_cnode = new_node->cast(); + if (new_cnode != nullptr) { + // Insert new node just before the old node. + order_.insert(iter, new_cnode); + } + // Remove old node from order list. + // Unused children nodes can be cleared by EraseUnusedNodeInOrder(). + order_.erase(iter); + // If the old node is an isolate node, replace it in the isolate node set as well. + ReplaceIsolateNode(old_node, new_node); +} + +void FuncGraph::ReplaceIsolateNode(const AnfNodePtr &old_node, const AnfNodePtr &new_node) { + if (isolate_nodes_.erase(old_node) == 0) { + // Skip if old node is not an isolate node. + return; + } + if (!new_node->isa()) { + // An isolate node cannot be replaced by a non-cnode. + MS_LOG(WARNING) << "Try to replace isolate node: " << old_node->DebugString() << " with: " << new_node->DebugString(); + return; + } + // Replace old node with the new one. + isolate_nodes_.insert(new_node); + // Replace the isolate node in the manager. + auto graph_manager = manager(); + if (graph_manager != nullptr) { + graph_manager->ReplaceIsolateNode(old_node, new_node); } } + +const std::vector FuncGraph::GetIsolateNodesInOrder() const { + if (isolate_nodes_.empty()) { + return {}; + } + if (isolate_nodes_.size() == 1) { + return std::vector(isolate_nodes_.cbegin(), isolate_nodes_.cend()); + } + std::vector ordered_isolate_nodes; + std::copy_if(order_.cbegin(), order_.cend(), std::back_inserter(ordered_isolate_nodes), + [&](const auto &node) { return isolate_nodes_.find(node) != isolate_nodes_.end(); }); + return ordered_isolate_nodes; +} + +static std::vector MakeInputNodes(const PrimitivePtr &primitive, const std::vector &inputs) { + std::vector input_node_list; + input_node_list.reserve(inputs.size() + 1); + input_node_list.emplace_back(std::make_shared(primitive)); + input_node_list.insert(input_node_list.end(), inputs.begin(), inputs.end()); + return input_node_list; +} + CNodePtr FuncGraph::NewCNode(const PrimitivePtr &primitive, const std::vector &inputs) { - auto primitive_node = std::make_shared(primitive); - std::vector input_node_list = {primitive_node}; - std::copy(inputs.begin(), inputs.end(), std::back_inserter(input_node_list)); + auto input_node_list = MakeInputNodes(primitive, inputs); return NewCNode(input_node_list); } +CNodePtr FuncGraph::NewCNodeInOrder(const PrimitivePtr &primitive, const std::vector &inputs) { + auto input_node_list = MakeInputNodes(primitive, inputs); + return NewCNodeInOrder(input_node_list); +} + ParameterPtr FuncGraph::add_weight(const tensor::MetaTensorPtr &meta_tensor) { auto parameter = add_parameter(); parameter->set_default_param(MakeValue(meta_tensor)); @@ -667,5 +729,4 @@ size_t NewFgSeenGeneration() { } const PrimitivePtr FuncGraphTransform::func_graph_prim_ = std::make_shared("FuncGraph"); -const char kFuncGraphFlagUndetermined[] = "Undeterminate"; } // namespace mindspore
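Before the header changes below, a quick illustration of how the order-aware constructors defined above are meant to be used. This is a hedged sketch only: fg, x and y are assumed to exist, and "Add" stands in for any primitive name.

  auto add_prim = std::make_shared<Primitive>("Add");
  CNodePtr add = fg->NewCNodeInOrder(add_prim, {x, y});  // bound to fg and appended to its order list
  CNodePtr before = fg->NewCNodeBefore(add, {NewValueNode(add_prim), y, x});  // evaluated before `add`
  CNodePtr after = fg->NewCNodeAfter(add, {NewValueNode(add_prim), x, x});    // evaluated after `add`
  // Plain NewCNode() now leaves the order list untouched.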
diff --git a/mindspore/core/ir/func_graph.h b/mindspore/core/ir/func_graph.h index 3cfcc624b5..1ce01e2cc1 100644 --- a/mindspore/core/ir/func_graph.h +++ b/mindspore/core/ir/func_graph.h @@ -19,6 +19,7 @@ #ifndef MINDSPORE_CORE_IR_FUNC_GRAPH_H_ #define MINDSPORE_CORE_IR_FUNC_GRAPH_H_ +#include #include #include #include @@ -27,12 +28,14 @@ #include #include #include +#include #include "ir/anf.h" #include "ir/manager.h" #include "utils/ordered_set.h" #include "utils/ordered_map.h" #include "base/base_ref.h" +#include "base/effect_info.h" #include "ir/func_graph_cloner.h" #include "abstract/abstract_value.h" @@ -80,6 +83,10 @@ const char FUNC_GRAPH_FLAG_CORE[] = "core"; const char FUNC_GRAPH_ATTR_GRAPH_KERNEL[] = "graph_kernel"; const char FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER[] = "spec_param"; +const char kFuncGraphFlagUndetermined[] = "Undeterminate"; +const char kFuncGraphFlagBackPropEntry[] = "BackPropEntry"; +const char kFuncGraphFlagReAutoMonad[] = "ReAutoMonad"; + namespace abstract { class AbstractKeywordArg; using AbstractKeywordArgPtr = std::shared_ptr; @@ -141,9 +148,7 @@ class FuncGraphBase : public Value { MS_DECLARE_PARENT(FuncGraphBase, Value); }; -extern const char kFuncGraphFlagUndetermined[]; - -class FuncGraph : public FuncGraphBase { +class FuncGraph : public FuncGraphBase, public EffectInfoHolder { public: FuncGraph(); using Drawer = std::function; @@ -169,11 +174,21 @@ class FuncGraph : public FuncGraphBase { // create a cnode with given inputs, bound to this graph virtual CNodePtr NewCNode(const std::vector &inputs = std::vector()); - - // create a cnode with given inputs, bound to this graph, and set to specific scope - CNodePtr NewCNodeWithScope(const std::vector &inputs, const ScopePtr &scope); virtual CNodePtr NewCNode(const PrimitivePtr &primitive, const std::vector &prim_inputs); + // create a cnode with given inputs, bound to this graph, and append it to the order list. + CNodePtr NewCNodeInOrder(const std::vector &inputs = std::vector()); + CNodePtr NewCNodeInOrder(const PrimitivePtr &primitive, const std::vector &prim_inputs); + + // create a cnode with given inputs, bound to this graph, and prepend it to the front of the order list. + CNodePtr NewCNodeInFront(const std::vector &inputs = std::vector()); + + // create a cnode with given inputs, and put it into the order list before the position node. + CNodePtr NewCNodeBefore(const AnfNodePtr &position, const std::vector &inputs); + + // create a cnode with given inputs, and put it into the order list after the position node. + CNodePtr NewCNodeAfter(const AnfNodePtr &position, const std::vector &inputs); + virtual ParameterPtr add_weight(const tensor::MetaTensorPtr &meta_tensor); // Functions for handling variable argument, keyword-only arguments and variable keyword argument AnfNodePtr GetDefaultValueByName(const std::string &name); @@ -330,8 +345,8 @@ class FuncGraph : public FuncGraphBase { const std::vector &specialized_parameter_list, std::unordered_map *repl_nodes); - const std::vector &paramter_obj_nodes() const { return paramter_obj_nodes_; } - void add_parameter_obj_node(const AnfNodePtr &p); + const std::vector &used_global_parameters() const { return used_global_parameters_; } + void add_used_global_parameters(const AnfNodePtr &p) { used_global_parameters_.push_back(p); } std::unordered_map attrs_; std::vector joined_shapes_; @@ -343,11 +358,37 @@ class FuncGraph : public FuncGraphBase { std::list GetOrderedCnodes(); void EraseUnusedNodeInOrder(const AnfNodePtr &n); void EraseUnusedNodeInOrder(); - void CheckOrder(); void DumpCNodeList(); - void ReleaseFullOrderToEffectOrder(); - void SetEffectDepends(const std::vector &depend_inputs); - bool HasEffect(const CNodePtr &cnode); + const std::list &order_list() const { return order_; } + + void set_order_list(std::list &&order_list) { order_ = std::move(order_list); } + + // Add a cnode at the end of the order list. + void AppendOrderList(const CNodePtr &cnode) { order_.push_back(cnode); } + + // Prepend a cnode at the front of the order list. + void PrependOrderList(const CNodePtr &cnode) { order_.insert(order_.begin(), cnode); } + + // Maintain cnode order list when a cnode is replaced by a new one. + void ReplaceInOrder(const AnfNodePtr &old_node, const AnfNodePtr &new_node); + + // Clear the cnode order list.
+ void ClearOrderList() { order_.clear(); } + + // Gets nodes that are not related to the output, e.g. side-effect calls. + const std::set &isolate_nodes() const { return isolate_nodes_; } + + // Add an isolate node. + void AddIsolateNode(const AnfNodePtr &node) { isolate_nodes_.insert(node); } + + // Replace an isolate node. + void ReplaceIsolateNode(const AnfNodePtr &old_node, const AnfNodePtr &new_node); + + // Clear isolate nodes. + void ClearIsolateNodes() { isolate_nodes_.clear(); } + + // Get isolate nodes, ordered as in the order list. + const std::vector GetIsolateNodesInOrder() const; bool stub() const { return stub_; } void set_stub(bool stub) { stub_ = stub; } @@ -382,7 +423,12 @@ class FuncGraph : public FuncGraphBase { // parameters of this function std::vector parameters_; - std::vector paramter_obj_nodes_; + + // global parameters used by this function. + std::vector used_global_parameters_; + + // isolate nodes, i.e. nodes that are not related to the output. + std::set isolate_nodes_; // whether there is a *args and **kwargs, and count kwonlyargs'number bool has_vararg_; diff --git a/mindspore/core/ir/func_graph_cloner.cc b/mindspore/core/ir/func_graph_cloner.cc index 52f883dcf2..a1cc9b8531 100644 --- a/mindspore/core/ir/func_graph_cloner.cc +++ b/mindspore/core/ir/func_graph_cloner.cc @@ -93,6 +93,7 @@ void Cloner::CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target) { new_node->set_attrs(old_node->attrs()); ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope(); new_node->set_scope(scope); + new_node->CloneUserData(old_node); if (IsParallelConsiderCNode(old_node) && new_node->scope() == kDefaultScope) { new_node->set_fullname_with_scope(old_node->fullname_with_scope()); } @@ -469,6 +470,42 @@ void Cloner::CloneAllNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &t for (auto &node : nodes) { CloneNode(node, target_func_graph); } + // func_graph is missing from repl_func_graph_ only when it was inlined; + if (repl_func_graph_.find(func_graph) != repl_func_graph_.end()) { + CloneOrderList(func_graph, target_func_graph); + CloneIsolateNodes(func_graph, target_func_graph); + } +} + +void Cloner::CloneOrderList(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) { + for (auto &cnode : func_graph->order_list()) { + auto it = repl_node_.find(cnode); + if (it == repl_node_.end()) { + // A cnode generated in the Analyze phase cannot be obtained from the nodes API of func_graph, + // so it cannot be cloned by the normal Clone API. + // If we ignore it, the order will be lost. + // Therefore we put this old node into the order list of target func_graph as a placeholder to + // keep the order. + // It may be replaced in ProgramSpecialize.
+ // If this disconnected node is not used in target func_graph, it will be cleared after + // ProgramSpecialize; + target_func_graph->AppendOrderList(cnode); + continue; + } + auto repl_cnode = dyn_cast(it->second); + if (repl_cnode) { + target_func_graph->AppendOrderList(repl_cnode); + } + } +} + +void Cloner::CloneIsolateNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) { + for (auto &node : func_graph->isolate_nodes()) { + auto it = repl_node_.find(node); + if (it != repl_node_.end()) { + target_func_graph->AddIsolateNode(it->second); + } + } } void Cloner::Run() { diff --git a/mindspore/core/ir/func_graph_cloner.h b/mindspore/core/ir/func_graph_cloner.h index 6d75c8d13c..6cd5387078 100644 --- a/mindspore/core/ir/func_graph_cloner.h +++ b/mindspore/core/ir/func_graph_cloner.h @@ -83,6 +83,8 @@ class Cloner { void AddTotalGraphs(const FuncGraphPtr &func_graph); bool CheckStatus(const FuncGraphPtr &func_graph, bool is_inline); void CloneAllNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph); + void CloneOrderList(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph); + void CloneIsolateNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph); void CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph); void CloneFuncGraphDefaultValues(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph); void InlineCloneParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList ¶ms); diff --git a/mindspore/core/ir/func_graph_extends.cc b/mindspore/core/ir/func_graph_extends.cc index be4c8cf707..3bf1a463be 100644 --- a/mindspore/core/ir/func_graph_extends.cc +++ b/mindspore/core/ir/func_graph_extends.cc @@ -59,7 +59,7 @@ void FuncGraph::set_output(const AnfNodePtr &value, bool force_new_ret) { if (force_new_ret || return_ == nullptr) { std::vector params({NewValueNode(prim::kPrimReturn), value}); FuncGraphPtr this_graph = shared_from_base(); - return_ = this_graph->NewCNode(params); + return_ = this_graph->NewCNodeInOrder(params); } else { if (manager_.lock()) { manager_.lock()->SetEdge(return_, 1, value); @@ -131,7 +131,7 @@ void FuncGraph::GenerateKwParams(const FuncGraphPtr &specialized_graph, std::string kw_param_name = kwarg->get_key(); MS_EXCEPTION_IF_NULL(specialized_graph); AnfNodePtr param_node = specialized_graph->GetParameterByName(kw_param_name); - // if not find correspoding parameter node + // if not find corresponding parameter node if (param_node == nullptr) { if (!has_kwarg()) { MS_LOG(EXCEPTION) << "Got unexpected keyword argument: " << kw_param_name; @@ -296,29 +296,6 @@ FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list) return specialized_graph; } -const char kPrimHasEffect[] = "_side_effect_flag"; - -bool FuncGraph::HasEffect(const CNodePtr &cnode) { - auto prim = GetCNodePrimitive(cnode); - if (prim != nullptr && prim->isa()) { - auto do_sig = prim->cast(); - auto prim_val = do_sig->function(); - if (prim_val != nullptr && prim_val->isa()) { - prim = prim_val->cast(); - } else { - prim = nullptr; - } - } - if (prim != nullptr) { - auto effect_val = prim->GetAttr(kPrimHasEffect); - if (effect_val && effect_val->isa()) { - auto effect_bool = GetValue(effect_val); - return effect_bool; - } - } - return false; -} - std::shared_ptr> FindRoots(const std::vector &segment) { std::shared_ptr> roots = std::make_shared>(segment); for (const auto &node : segment) { @@ -364,62 +341,4 @@ std::shared_ptr> 
FindLeaves(const std::vector &se } return nodes; } - -void FuncGraph::ReleaseFullOrderToEffectOrder() { - MS_LOG(DEBUG) << "Flag has_effect " << has_flag(GRAPH_FLAG_HAS_EFFECT) << "."; - if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { - std::list depends_order; - std::vector segment; - for (const auto &cnode : order_) { - if (IsPrimitiveCNode(cnode, prim::kPrimReturn)) { - continue; - } - if (HasEffect(cnode)) { - MS_LOG(DEBUG) << "Meet a effect node " << cnode->DebugString() << "."; - if (segment.size() > 0) { - auto roots = FindRoots(segment); - for (auto iter = roots->begin(); iter != roots->end(); (void)iter++) { - depends_order.push_back(*iter); - } - } - segment.clear(); - depends_order.push_back(cnode); - } else { - MS_LOG(DEBUG) << "Meet a general node " << cnode->DebugString() << "."; - segment.push_back(cnode); - } - } - if (segment.size() > 1) { - auto roots = FindRoots(segment); - for (auto iter = roots->begin(); iter != roots->end(); (void)iter++) { - depends_order.push_back(*iter); - } - } - std::vector depend_inputs; - auto old_ret = output(); - for (auto iter = depends_order.rbegin(); iter != depends_order.rend(); (void)iter++) { - if (*iter != old_ret) { - depend_inputs.push_back(*iter); - } - } - set_flag(GRAPH_FLAG_HAS_EFFECT, false); - set_flag(GRAPH_FLAG_EFFECT_PATIAL_ORDER, true); - if (!depend_inputs.empty()) { - SetEffectDepends(depend_inputs); - } - } -} - -void FuncGraph::SetEffectDepends(const std::vector &depend_inputs) { - auto old_ret = output(); - std::vector inputs{NewValueNode(prim::kPrimDepend), old_ret}; - (void)inputs.insert(inputs.end(), depend_inputs.begin(), depend_inputs.end()); - auto new_ret = NewCNode(inputs); - auto mng = manager(); - if (mng) { - (void)mng->Replace(old_ret, new_ret); - } else { - return_->set_input(1, new_ret); - } -} } // namespace mindspore diff --git a/mindspore/core/ir/graph_utils.cc b/mindspore/core/ir/graph_utils.cc index b0dc47bc2b..e7df960a0e 100644 --- a/mindspore/core/ir/graph_utils.cc +++ b/mindspore/core/ir/graph_utils.cc @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -92,10 +93,10 @@ std::vector TopoSort(const AnfNodePtr &root, const SuccFunc &succ, c } // search the cnodes inside this graph only -std::vector BroadFirstSearchGraphCNodes(CNodePtr ret) { +std::vector BroadFirstSearchGraphCNodes(const std::vector &starts) { std::deque todo(1024); todo.clear(); - todo.push_back(ret); + todo.insert(todo.end(), starts.begin(), starts.end()); std::vector sorted_nodes; auto seen = NewSeenGeneration(); while (!todo.empty()) { @@ -117,6 +118,33 @@ std::vector BroadFirstSearchGraphCNodes(CNodePtr ret) { return sorted_nodes; } +// search for the first cnode that matches the predicate, inside this graph only +CNodePtr BroadFirstSearchFirstOf(const std::vector &starts, const MatchFunc &match_predicate) { + std::deque todo(1024); + todo.clear(); + todo.insert(todo.end(), starts.begin(), starts.end()); + auto seen = NewSeenGeneration(); + while (!todo.empty()) { + CNodePtr top = todo.front(); + todo.pop_front(); + if (match_predicate(top)) { + return top; + } + auto inputs = top->inputs(); + for (auto &item : inputs) { + if (item->seen_ == seen) { + continue; + } + + if (item->isa()) { + todo.push_back(item->cast()); + } + item->seen_ = seen; + } + } + return nullptr; +} + std::vector BroadFirstSearchGraphUsed(FuncGraphPtr root) { std::deque todo; todo.push_back(root); @@ -138,6 +166,14 @@ std::vector BroadFirstSearchGraphUsed(FuncGraphPtr root) { return sorted; } +// PushSuccessors pushes cnode inputs to a vector as
successors for topo sort. +static void PushSuccessors(const CNodePtr &cnode, std::vector *vecs) { + auto &inputs = cnode->inputs(); + vecs->reserve(vecs->size() + inputs.size()); + // To keep evaluate order from left to right, we push inputs in reversed order. + vecs->insert(vecs->end(), inputs.rbegin(), inputs.rend()); +} + std::vector SuccDeeper(const AnfNodePtr &node) { std::vector vecs; if (node == nullptr) { @@ -153,8 +189,7 @@ std::vector SuccDeeper(const AnfNodePtr &node) { return vecs; } else if (node->func_graph() != nullptr) { if (node->isa()) { - auto &inputs = node->cast()->inputs(); - (void)vecs.insert(vecs.end(), inputs.begin(), inputs.end()); + PushSuccessors(node->cast(), &vecs); } return vecs; } @@ -177,8 +212,7 @@ std::vector SuccDeeperSimple(const AnfNodePtr &node) { return vecs; } else { if (node->isa()) { - auto &inputs = node->cast()->inputs(); - (void)vecs.insert(vecs.end(), inputs.begin(), inputs.end()); + PushSuccessors(node->cast(), &vecs); } return vecs; } @@ -186,13 +220,9 @@ std::vector SuccDeeperSimple(const AnfNodePtr &node) { std::vector SuccIncoming(const AnfNodePtr &node) { std::vector vecs; - if (node == nullptr) { - return vecs; - } - - if (node->isa()) { - auto &inputs = node->cast()->inputs(); - (void)vecs.insert(vecs.end(), inputs.begin(), inputs.end()); + auto cnode = dyn_cast(node); + if (cnode != nullptr) { + PushSuccessors(cnode, &vecs); } return vecs; } @@ -216,11 +246,20 @@ std::vector SuccIncludeFV(const FuncGraphPtr &fg, const AnfNodePtr & } } } - (void)vecs.insert(vecs.end(), inputs.begin(), inputs.end()); + PushSuccessors(cnode, &vecs); } return vecs; } +const std::vector &GetInputs(const AnfNodePtr &node) { + static std::vector empty_inputs; + auto cnode = dyn_cast(node); + if (cnode != nullptr) { + return cnode->inputs(); + } + return empty_inputs; +} + IncludeType AlwaysInclude(const AnfNodePtr &) { return FOLLOW; } IncludeType IncludeBelongGraph(const FuncGraphPtr &fg, const AnfNodePtr &node) { diff --git a/mindspore/core/ir/graph_utils.h b/mindspore/core/ir/graph_utils.h index b1b42fea6d..b244ccb902 100644 --- a/mindspore/core/ir/graph_utils.h +++ b/mindspore/core/ir/graph_utils.h @@ -41,6 +41,7 @@ using IncludeFunc = std::function; using FilterFunc = std::function; using SuccFunc = std::function(AnfNodePtr)>; using SearchFunc = std::function(const AnfNodePtr &, const IncludeFunc &)>; +using MatchFunc = std::function; std::vector DeepScopedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include); std::vector DeepUsedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include); @@ -51,6 +52,8 @@ std::vector SuccDeeperSimple(const AnfNodePtr &node); std::vector SuccIncoming(const AnfNodePtr &node); std::vector SuccIncludeFV(const FuncGraphPtr &fg, const AnfNodePtr &node); +const std::vector &GetInputs(const AnfNodePtr &node); + IncludeType AlwaysInclude(const AnfNodePtr &node); IncludeType IncludeBelongGraph(const FuncGraphPtr &fg, const AnfNodePtr &node); @@ -68,8 +71,11 @@ std::vector DeepUsersSearch(const AnfNodePtr &root, const IncludeFun std::vector TopoSort(const AnfNodePtr &root, const SuccFunc &succ = SuccIncoming, const IncludeFunc &include = AlwaysInclude); -std::vector BroadFirstSearchGraphCNodes(CNodePtr ret); +std::vector BroadFirstSearchGraphCNodes(const std::vector &starts); std::vector BroadFirstSearchGraphUsed(FuncGraphPtr root); + +CNodePtr BroadFirstSearchFirstOf(const std::vector &starts, const MatchFunc &match_predicate); + class FuncGraphIndex { public: explicit FuncGraphIndex(const FuncGraphPtr &fg, const 
SearchFunc &search = DeepScopedGraphSearch, diff --git a/mindspore/core/ir/graph_utils_extends.cc b/mindspore/core/ir/graph_utils_extends.cc index 1662e6111f..ffbc822bdd 100644 --- a/mindspore/core/ir/graph_utils_extends.cc +++ b/mindspore/core/ir/graph_utils_extends.cc @@ -42,11 +42,11 @@ class DeepFirstSearcher : public AnfIrVisitor { std::vector Search(const AnfNodePtr &root) { if (root == nullptr) { - return res_; + return std::move(res_); } seen_ = NewSeenGeneration(); Visit(root); - return res_; + return std::move(res_); } void Visit(const AnfNodePtr &node) override { diff --git a/mindspore/core/ir/manager.cc b/mindspore/core/ir/manager.cc index 6ac91fcce1..fe3266a489 100644 --- a/mindspore/core/ir/manager.cc +++ b/mindspore/core/ir/manager.cc @@ -197,11 +197,18 @@ void FuncGraphManager::AddFuncGraph(FuncGraphPtr func_graph, bool is_root) { if (func_graphs_.contains(func_graph)) { return; } + + // New nodes to be acquired. + std::vector new_nodes = func_graph->parameters(); + new_nodes.emplace_back(func_graph->get_return()); + auto &isolate_nodes = func_graph->isolate_nodes(); + new_nodes.insert(new_nodes.end(), isolate_nodes.begin(), isolate_nodes.end()); + + // Add func_graph as a managed graph. AddIntoManaged(func_graph); - std::vector para = func_graph->parameters(); - AcquireNodes(para); - std::vector return_vec({func_graph->get_return()}); - AcquireNodes(return_vec); + + // Acquire all nodes from func_graph. + AcquireNodes(new_nodes); } // clear the all information in manager @@ -210,6 +217,7 @@ void FuncGraphManager::Clear() { all_nodes_.clear(); node_users_.clear(); roots_.clear(); + isolate_nodes_.clear(); signals_->InvalidateComputer(); } @@ -274,6 +282,8 @@ void FuncGraphManager::AddIntoManaged(const FuncGraphPtr &fg) { FuncGraphManagerPtr this_manager = shared_from_this(); fg->set_manager(this_manager); } + const auto &fg_isolate_nodes = fg->isolate_nodes(); + isolate_nodes_.insert(fg_isolate_nodes.begin(), fg_isolate_nodes.end()); func_graphs_.add(fg); } @@ -433,10 +443,14 @@ void FuncGraphManager::AddParameter(const FuncGraphPtr &fg, const AnfNodePtr &pa } bool FuncGraphManager::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) { + auto func_graph = old_node->func_graph(); auto tr = Transact(); bool success = tr.Replace(old_node, new_node); if (success) { tr.Commit(); + if (func_graph != nullptr) { + func_graph->ReplaceInOrder(old_node, new_node); + } } return success; } @@ -447,6 +461,12 @@ void FuncGraphManager::SetEdge(const AnfNodePtr &node, int index, const AnfNodeP tr.Commit(); } +void FuncGraphManager::AddEdge(const AnfNodePtr &node, const AnfNodePtr &value) { + auto tr = Transact(); + tr.AddEdge(node, value); + tr.Commit(); +} + void FuncGraphManager::MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr target, const ScopePtr &scope) { AnfNodePtr source_return = source->get_return(); AnfNodePtr source_output = source->output(); @@ -549,6 +569,13 @@ void FuncGraphManager::ParseChanges(const std::vector &changes, EdgeTupl (*adds)[edge.new_node] += 1; edge.root_node->set_input(edge.index, edge.new_node); } break; + case Change::kTxAddEdge: { + auto edge = args.cast(); + auto index = edge.root_node->inputs().size(); + (*add_edges)[std::make_pair(edge.root_node, std::make_pair(index, edge.new_node))] += 1; + (*adds)[edge.new_node] += 1; + edge.root_node->add_input(edge.new_node); + } break; case Change::kTxSetParams: { auto param = args.cast(); MS_EXCEPTION_IF_NULL(param.func_graph); @@ -614,6 +641,29 @@ void FuncGraphManager::CommitChanges(const 
std::vector &changes) { MaybeDropFuncGraphs(*drop_func_graphs); } +void FuncGraphManager::ReplaceIsolateNode(const AnfNodePtr &old_node, const AnfNodePtr &new_node) { + MS_EXCEPTION_IF_NULL(old_node); + MS_EXCEPTION_IF_NULL(new_node); + if (isolate_nodes_.erase(old_node) == 0) { + return; + } + if (!new_node->isa()) { + MS_LOG(EXCEPTION) << "Replace isolate node: " << old_node->DebugString() + << " with non-cnode: " << new_node->DebugString(); + } + isolate_nodes_.insert(new_node); +} + +void FuncGraphManager::ClearIsolateNodes() { + // If FuncGraph A has an isolate node whose input is FuncGraph B, then B has been added to FuncGraph A's + // value nodes by the AddFuncGraph API. So if that isolate node is totally unused after AutoMonad, FuncGraph B + // should be removed from FuncGraph A's value nodes; otherwise it will confuse FVTotalComputer. + std::vector isolate_nodes_vec(isolate_nodes_.cbegin(), isolate_nodes_.cend()); + auto drop_func_graphs = MaybeDropNodes(isolate_nodes_vec); + MaybeDropFuncGraphs(*drop_func_graphs); + isolate_nodes_.clear(); +} + void FuncGraphTransaction::SetParameters(FuncGraphPtr fg, const std::vector &params) { changes_.emplace_back(Change::kTxSetParams, ArgsOfSetParams{fg, params}); } @@ -650,6 +700,15 @@ void FuncGraphTransaction::SetEdge(const AnfNodePtr &src_node, int k, const AnfN changes_.emplace_back(Change::kTxSetEdge, ArgsOfSetEdge{cnode, v, IntToSize(k)}); } +void FuncGraphTransaction::AddEdge(const AnfNodePtr &src_node, const AnfNodePtr &v) { + MS_EXCEPTION_IF_NULL(src_node); + auto cnode = src_node->cast(); + if (cnode == nullptr) { + MS_LOG(EXCEPTION) << "src_node should be a cnode, but cast failed."; + } + changes_.emplace_back(Change::kTxAddEdge, ArgsOfAddEdge{cnode, v}); +} + void FuncGraphTransaction::Commit() { std::vector changes; changes_.swap(changes); diff --git a/mindspore/core/ir/manager.h b/mindspore/core/ir/manager.h index d961e94ae5..41f2b3b1dd 100644 --- a/mindspore/core/ir/manager.h +++ b/mindspore/core/ir/manager.h @@ -314,6 +314,7 @@ class FuncGraphManager : public std::enable_shared_from_this { void MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool ignore_users = false); bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node); void SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value); + void AddEdge(const AnfNodePtr &node, const AnfNodePtr &value); void MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr target, const ScopePtr &scope); FuncGraphTransaction Transact(); @@ -350,6 +351,15 @@ class FuncGraphManager : public std::enable_shared_from_this { IncludeType Limit(const AnfNodePtr &node); + // Gets isolate nodes that are not related to the output, e.g. side-effect calls. + const std::set &isolate_nodes() const { return isolate_nodes_; } + + // Replace a node in the isolate node list. + void ReplaceIsolateNode(const AnfNodePtr &old_node, const AnfNodePtr &new_node); + + // Clear all isolate nodes. + void ClearIsolateNodes(); + // Static Analysis NodeUsersMap node_users_; AnfNodeSet all_nodes_; // managed nodes @@ -383,6 +393,9 @@ class FuncGraphManager : public std::enable_shared_from_this { std::shared_ptr recursive_; std::shared_ptr j_total_; + // Isolate Nodes + std::set isolate_nodes_; + bool is_manage_; std::function limit_; }; @@ -406,8 +419,10 @@ class FuncGraphTransaction { // replace old_node with new_node bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node); - // set esge, i.e., declare setting node.inputs[key] to value. + // set edge, i.e., declare setting node.inputs[key] to value.
void SetEdge(const AnfNodePtr &src_node, int k, const AnfNodePtr &v); + // Add edge, i.e., append value to node.inputs. + void AddEdge(const AnfNodePtr &src_node, const AnfNodePtr &v); // commit all changes void Commit(); @@ -454,8 +469,20 @@ struct ArgsOfSetEdge { } }; +// args for add edge +struct ArgsOfAddEdge { + CNodePtr root_node; + AnfNodePtr new_node; + bool operator==(const ArgsOfAddEdge &other) const { return &other == this; } + + friend std::ostream &operator<<(std::ostream &os, const ArgsOfAddEdge &other) { + os << "[ArgsOfAddEdge]"; + return os; + } +}; + struct Change { - enum OpName { kTxSetParams, kTxSetEdge, kTxAddParam }; + enum OpName { kTxSetParams, kTxSetEdge, kTxAddParam, kTxAddEdge }; OpName op; Any args; Change(OpName name, const Any ¶) : op(name), args(para) {} diff --git a/mindspore/core/ir/primitive.cc b/mindspore/core/ir/primitive.cc index 993e0bab83..622df78d35 100644 --- a/mindspore/core/ir/primitive.cc +++ b/mindspore/core/ir/primitive.cc @@ -35,6 +35,19 @@ Primitive::Primitive(const std::string &name, const bool is_base, const PrimType is_const_prim_(false), id_(MakeId()) {} +Primitive::Primitive(const std::string &name, const std::unordered_map &attrs) + : Named(name), + is_base_(true), + has_signature_(false), + prim_type_(kPrimTypeBuiltIn), + record_evaluate_add_attr_(false), + is_const_prim_(false), + id_(MakeId()) { + for (auto &attr : attrs) { + attrs_[attr.first] = attr.second; + } +} + Primitive::Primitive(const Primitive &prim) : Named(prim), attrs_(prim.attrs_), diff --git a/mindspore/core/ir/primitive.h b/mindspore/core/ir/primitive.h index 994526122e..b57e883579 100644 --- a/mindspore/core/ir/primitive.h +++ b/mindspore/core/ir/primitive.h @@ -42,6 +42,7 @@ enum PrimType { class Primitive : public Named { public: explicit Primitive(const std::string &name, const bool is_base = true, const PrimType prim_type = kPrimTypeBuiltIn); + Primitive(const std::string &name, const std::unordered_map &attrs); Primitive(const Primitive &prim); MS_DECLARE_PARENT(Primitive, Named); abstract::AbstractBasePtr ToAbstract(); @@ -147,7 +148,7 @@ struct PrimitiveEqual { bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const { MS_EXCEPTION_IF_NULL(t1); MS_EXCEPTION_IF_NULL(t2); - return t1->name() == t2->name(); + return t1 == t2 || t1->name() == t2->name(); } }; diff --git a/mindspore/core/ir/value.cc b/mindspore/core/ir/value.cc index 4e6d1de6a3..0f589fedbe 100644 --- a/mindspore/core/ir/value.cc +++ b/mindspore/core/ir/value.cc @@ -303,4 +303,11 @@ bool ValueDictionary::operator==(const ValueDictionary &other) const { } return true; } + +bool UMonad::operator==(const Value &other) const { return other.isa(); } +const ValuePtr kUMonad = std::make_shared(); + +bool IOMonad::operator==(const Value &other) const { return other.isa(); } +const ValuePtr kIOMonad = std::make_shared(); + } // namespace mindspore diff --git a/mindspore/core/ir/value.h b/mindspore/core/ir/value.h index c01f772232..a3e43d2bc8 100644 --- a/mindspore/core/ir/value.h +++ b/mindspore/core/ir/value.h @@ -253,6 +253,42 @@ class AnyValue : public Value { }; extern const ValuePtr kAnyValue; +class Monad : public Value { + public: + ~Monad() override = default; + MS_DECLARE_PARENT(Monad, Value) + abstract::AbstractBasePtr ToAbstract() override = 0; + + protected: + explicit Monad(TypePtr type) : Value(type) {} +}; + +class UMonad : public Monad { + public: + UMonad() : Monad(kUMonadType) {} + ~UMonad() override = default; + MS_DECLARE_PARENT(UMonad, Monad) + std::size_t hash() const 
override { return tid(); } + bool operator==(const Value &other) const override; + abstract::AbstractBasePtr ToAbstract() override; + std::string ToString() const override { return "U"; } +}; +using UMonadPtr = std::shared_ptr; +extern const ValuePtr kUMonad; + +class IOMonad : public Monad { + public: + IOMonad() : Monad(kIOMonadType) {} + ~IOMonad() override = default; + MS_DECLARE_PARENT(IOMonad, Monad) + std::size_t hash() const override { return tid(); } + bool operator==(const Value &other) const override; + abstract::AbstractBasePtr ToAbstract() override; + std::string ToString() const override { return "IO"; } +}; +using IOMonadPtr = std::shared_ptr; +extern const ValuePtr kIOMonad; + template <> inline const char *GetValue(const ValuePtr &value) { if (value == nullptr) { diff --git a/mindspore/core/ir/value_extends.cc b/mindspore/core/ir/value_extends.cc index c75da80665..c2eada9bcd 100644 --- a/mindspore/core/ir/value_extends.cc +++ b/mindspore/core/ir/value_extends.cc @@ -82,4 +82,8 @@ abstract::AbstractBasePtr ValueDictionary::ToAbstract() { [](const std::pair &item) { return std::make_pair(item.first, item.second->ToAbstract()); }); return std::make_shared(kv); } + +abstract::AbstractBasePtr UMonad::ToAbstract() { return std::make_shared(); } + +abstract::AbstractBasePtr IOMonad::ToAbstract() { return std::make_shared(); } } // namespace mindspore diff --git a/mindspore/core/load_mindir/anf_model_parser.cc b/mindspore/core/load_mindir/anf_model_parser.cc index f787e3af10..3164c1a85b 100644 --- a/mindspore/core/load_mindir/anf_model_parser.cc +++ b/mindspore/core/load_mindir/anf_model_parser.cc @@ -43,14 +43,13 @@ enum ParseForm : int { FORM_PARSE_SCALAR = 1, FORM_PARSE_TENSOR = 2, FORM_PARSE_NONE = 3, - FORM_PARSE_UNDEFINE = 4, + FORM_PARSE_MONAD = 4, + FORM_PARSE_UNDEFINE = 5, }; -static std::map kParseTypeSwitchMap{{"type", FORM_PARSE_TYPE}, - {"scalar", FORM_PARSE_SCALAR}, - {"tensor", FORM_PARSE_TENSOR}, - {"none", FORM_PARSE_NONE}, - {"", FORM_PARSE_UNDEFINE}}; +static std::map kParseTypeSwitchMap{ + {"type", FORM_PARSE_TYPE}, {"scalar", FORM_PARSE_SCALAR}, {"tensor", FORM_PARSE_TENSOR}, + {"none", FORM_PARSE_NONE}, {"Monad", FORM_PARSE_MONAD}, {"", FORM_PARSE_UNDEFINE}}; static std::unordered_map kDefaultValueSwitchMap{ {mind_ir::TensorProto_DataType_BOOL, kNumberTypeBool}, @@ -574,6 +573,29 @@ bool MSANFModelParser::ObtainValueNodeInNoneForm(const std::string &value_node_n return true; } +bool MSANFModelParser::ObtainValueNodeInMonadForm(const std::string &value_node_name, + const mind_ir::AttributeProto &attr_proto) { + const std::string &ref_attr_name = attr_proto.ref_attr_name(); + if (ref_attr_name.find("UMonad") != std::string::npos) { + const ValuePtr kUMonad = std::make_shared(); + auto monad_abs = kUMonad->ToAbstract(); + auto new_value_node = NewValueNode(kUMonad); + MS_EXCEPTION_IF_NULL(new_value_node); + new_value_node->set_abstract(monad_abs); + anfnode_build_map_[value_node_name] = new_value_node; + } else if (ref_attr_name.find("IOMonad") != std::string::npos) { + const ValuePtr kIOMonad = std::make_shared(); + auto monad_abs = kIOMonad->ToAbstract(); + auto new_value_node = NewValueNode(kIOMonad); + MS_EXCEPTION_IF_NULL(new_value_node); + new_value_node->set_abstract(monad_abs); + anfnode_build_map_[value_node_name] = new_value_node; + } else { + return false; + } + return true; +} + bool MSANFModelParser::GetAttrValueForValueNode(const std::string &value_node_name, const mind_ir::AttributeProto &attr_proto) { if (!attr_proto.has_ref_attr_name()) { @@ 
-589,6 +611,8 @@ bool MSANFModelParser::GetAttrValueForValueNode(const std::string &value_node_na type = ref_attr_name.substr(pos, string("type:").length() - 1); } else if ((pos = ref_attr_name.find("tensor:")) != std::string::npos) { type = ref_attr_name.substr(pos, string("tensor:").length() - 1); + } else if ((pos = ref_attr_name.find("Monad:")) != std::string::npos) { + type = ref_attr_name.substr(pos, string("Monad:").length() - 1); } else if (ref_attr_name == "none") { type = ref_attr_name; } @@ -620,6 +644,10 @@ bool MSANFModelParser::GetAttrValueForValueNode(const std::string &value_node_na ObtainValueNodeInNoneForm(value_node_name, attr_proto); break; } + case FORM_PARSE_MONAD: { + ObtainValueNodeInMonadForm(value_node_name, attr_proto); + break; + } default: MS_LOG(ERROR) << "parse attr type don't support the ref_attr_name: " << ref_attr_name; return false; @@ -722,11 +750,17 @@ CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFunc MS_EXCEPTION_IF_NULL(cnode_ptr); if (0 == kv.size()) { - AbstractBasePtrList elem; - for (size_t index = 1; index < cnode_ptr->inputs().size(); ++index) { - elem.push_back(cnode_ptr->input(index)->abstract()); + if (node_type == "UpdateState") { + const ValuePtr kUMonad = std::make_shared(); + auto monad_abs = kUMonad->ToAbstract(); + cnode_ptr->set_abstract(monad_abs); + } else { + AbstractBasePtrList elem; + for (size_t index = 1; index < cnode_ptr->inputs().size(); ++index) { + elem.push_back(cnode_ptr->input(index)->abstract()); + } + cnode_ptr->set_abstract(std::make_shared(elem)); } - cnode_ptr->set_abstract(std::make_shared(elem)); } else if (1 == kv.size()) { std::unordered_map::iterator iter = kv.begin(); cnode_ptr->set_abstract(iter->second); diff --git a/mindspore/core/load_mindir/anf_model_parser.h b/mindspore/core/load_mindir/anf_model_parser.h index 805bb765b1..2ac231b3da 100644 --- a/mindspore/core/load_mindir/anf_model_parser.h +++ b/mindspore/core/load_mindir/anf_model_parser.h @@ -64,6 +64,7 @@ class MSANFModelParser { bool GetAttrValueForValueNode(const std::string &value_node_name, const mind_ir::AttributeProto &attr_tensor); bool ObtainValueNodeInTypeForm(const string &value_node_name, const mind_ir::TensorProto &attr_tensor); bool ObtainValueNodeInNoneForm(const std::string &value_node_name, const mind_ir::AttributeProto &attr_proto); + bool ObtainValueNodeInMonadForm(const std::string &value_node_name, const mind_ir::AttributeProto &attr_proto); std::unordered_map GetAbstractForCNode( const mind_ir::AttributeProto &attr_proto); diff --git a/mindspore/core/utils/flags.cc b/mindspore/core/utils/flags.cc index 671f62d5cf..56bbc4ff5b 100644 --- a/mindspore/core/utils/flags.cc +++ b/mindspore/core/utils/flags.cc @@ -24,6 +24,11 @@ const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[] = "_effect_patial_order"; const char GRAPH_FLAG_CACHE_ENABLE[] = "cache_enable"; const char GRAPH_FLAG_RANDOM_EFFECT[] = "_random_effect"; const char GRAPH_FLAG_SIDE_EFFECT[] = "_side_effect"; +const char GRAPH_FLAG_SIDE_EFFECT_IO[] = "side_effect_io"; +const char GRAPH_FLAG_SIDE_EFFECT_MEM[] = "side_effect_mem"; +const char GRAPH_FLAG_SIDE_EFFECT_EXCEPTION[] = "side_effect_exception"; +const char GRAPH_FLAG_SIDE_EFFECT_PROPAGATE[] = "side_effect_propagate"; +const char GRAPH_FLAG_SIDE_EFFECT_BACKPROP[] = "side_effect_backprop"; // method names of python primitive called from c++ source code // 1. 
infer method name of class 'PrimitiveWithInfer' @@ -41,4 +46,5 @@ const char ATTR_MIN_SHAPE[] = "min_shape"; const char ATTR_MAX_SHAPE[] = "max_shape"; const char ATTR_MIN_VALUE[] = "min_value"; const char ATTR_MAX_VALUE[] = "max_value"; +const char ATTR_NO_BROADEN[] = "no_broaden"; } // namespace mindspore diff --git a/mindspore/core/utils/flags.h b/mindspore/core/utils/flags.h index 8dd9fdfbbf..5a4d8188d4 100644 --- a/mindspore/core/utils/flags.h +++ b/mindspore/core/utils/flags.h @@ -24,6 +24,11 @@ extern const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[]; extern const char GRAPH_FLAG_CACHE_ENABLE[]; extern const char GRAPH_FLAG_RANDOM_EFFECT[]; extern const char GRAPH_FLAG_SIDE_EFFECT[]; +extern const char GRAPH_FLAG_SIDE_EFFECT_IO[]; +extern const char GRAPH_FLAG_SIDE_EFFECT_MEM[]; +extern const char GRAPH_FLAG_SIDE_EFFECT_EXCEPTION[]; +extern const char GRAPH_FLAG_SIDE_EFFECT_PROPAGATE[]; +extern const char GRAPH_FLAG_SIDE_EFFECT_BACKPROP[]; extern const char PY_PRIM_METHOD_INFER[]; extern const char PY_PRIM_METHOD_CHECK[]; @@ -36,6 +41,7 @@ extern const char ATTR_MIN_SHAPE[]; extern const char ATTR_MAX_SHAPE[]; extern const char ATTR_MIN_VALUE[]; extern const char ATTR_MAX_VALUE[]; +extern const char ATTR_NO_BROADEN[]; } // namespace mindspore #endif // MINDSPORE_CORE_UTILS_FLAGS_H diff --git a/mindspore/core/utils/log_adapter.h b/mindspore/core/utils/log_adapter.h index 7020a30e7a..3865d92fdf 100644 --- a/mindspore/core/utils/log_adapter.h +++ b/mindspore/core/utils/log_adapter.h @@ -179,7 +179,7 @@ class LogWriter { excp_type) ^ \ mindspore::LogStream() -#define IS_OUTPUT_ON(level) (level) >= mindspore::g_ms_submodule_log_levels[SUBMODULE_ID] +#define IS_OUTPUT_ON(level) ((level) >= mindspore::g_ms_submodule_log_levels[SUBMODULE_ID]) #define MS_LOG(level) MS_LOG_##level diff --git a/mindspore/core/utils/ordered_set.h b/mindspore/core/utils/ordered_set.h index 5336ceb38c..3f10acd2ac 100644 --- a/mindspore/core/utils/ordered_set.h +++ b/mindspore/core/utils/ordered_set.h @@ -146,13 +146,20 @@ class OrderedSet { MS_LOG(EXCEPTION) << "pop() on empty OrderedSet"; } - T back() { + T &back() { if (ordered_data_.size() != 0) { return ordered_data_.back(); } MS_LOG(EXCEPTION) << "back() on empty OrderedSet"; } + T &front() { + if (ordered_data_.size() != 0) { + return ordered_data_.front(); + } + MS_LOG(EXCEPTION) << "front() on empty OrderedSet"; + } + // Return true if there are no common elements bool is_disjoint(const OrderedSet &other) { for (auto &item : other.ordered_data_) { diff --git a/mindspore/core/utils/parallel_node_check.cc b/mindspore/core/utils/parallel_node_check.cc index 770f260c5e..52215fdd54 100644 --- a/mindspore/core/utils/parallel_node_check.cc +++ b/mindspore/core/utils/parallel_node_check.cc @@ -31,7 +31,7 @@ static const std::set<std::string> PARALLEL_BLACK_LIST_ = {prim::kTupleGetItem, "get_ref_value", "get_ref_origin", "dot", "im2col", "col2im", "im2col_v1", "state_setitem", "ScalarSummary", "ImageSummary", "TensorSummary", "Debug", "HistogramSummary", "col2im_v1", "resolve", "BroadcastGradientArgs", "InvertPermutation", "ControlDepend", "DropoutGenMask", "embed", "create_instance", "RefToEmbed", - "stop_gradient", "Send"}; + "stop_gradient", "Send", "UpdateState", "Load"}; // clang-format on bool IsInParallelBlackList(const PrimitivePtr &prim) {
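UpdateState and Load join the parallel black list because they carry execution ordering, not data. A toy, self-contained model of the chain they build (plain Python, no MindSpore types; the names only echo the primitives):

```python
class State:
    """Stand-in for a UMonad token; only its identity/ordering matters."""

    def __init__(self, version=0):
        self.version = version


def assign(store, name, value, u):
    """Write, then hand back a fresh token -- like Assign + UpdateState."""
    store[name] = value
    return State(u.version + 1)


def load(store, name, u):
    """Read under a token -- like Load(param, u)."""
    return store[name]


store, u0 = {}, State()
u1 = assign(store, "p", 1.0, u0)   # the write consumes u0
print(load(store, "p", u1))        # 1.0; the read cannot float above the write
```

diff --git a/mindspore/nn/_graph_kernels/graph_kernels.py b/mindspore/nn/_graph_kernels/graph_kernels.py deleted file mode 100644 index 869a67bd59..0000000000 --- a/mindspore/nn/_graph_kernels/graph_kernels.py +++ /dev/null @@ -1,259 +0,0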
@@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -""" -Graph kernels. They are composites of basic primitives and can be compiled into -a fused kernel automatically when context.set_context(enable_graph_kernel=True). -""" -from ...ops import operations as P -from ...ops.primitive import PrimitiveWithInfer, prim_attr_register -from ...ops.composite import multitype_ops as C -from ..cell import GraphKernel - - -class InplaceAssign(PrimitiveWithInfer): - """ - Inplace assign `Parameter` with a value. - - This primitive can only be used in graph kernel. - - Inputs: - - **variable** (Parameter) - The `Parameter`. - - **value** (Tensor) - The value to be assigned. - - **depend** (Tensor) - The dependent tensor to keep this op connected in graph. - - Outputs: - Tensor, has the same type as original `variable`. - - Examples: - >>> class MyClass(GraphKernel): - ... def __init__(self): - ... super(MyClass, self).__init__() - ... self.mul = P.Mul() - ... self.fake_output_assign = InplaceAssign() - ... self.fake_output_assign.add_prim_attr("fake_output", True) - ... - ... def construct(self, i0, i1): - ... mul_res = self.mul(i0, i1) - ... # mul_res is a fake output and parameter i0 will be updated. - ... mul_res = self.fake_output_assign(i0, mul_res, mul_res) - ... return mul_res - """ - - @prim_attr_register - def __init__(self): - super(InplaceAssign, self).__init__("InplaceAssign") - self.init_prim_io_names(inputs=['x', 'y', 'z'], outputs=['output']) - - def infer_shape(self, x, y, z): - return z - - def infer_dtype(self, x, y, z): - return z - - def get_bprop(self): - def bprop(x, y, z, out, dout): - return (x, C.zeros_like(y), dout) - - return bprop - - -class LambUpdateWithLR(GraphKernel): - r""" - Part of Lamb optimizer. - - .. math:: - s_1 = select(i_1 \gt y_g, select(i_0 \gt y_g, \frac{i_1}{i_2}, se), se) - i_5 = i_5 - max(min(s_1, y_m), y_g) \times i_3 \times i_4 - - Inputs: - - **input0** (Tensor) - The first tensor to be computed. - - **input1** (Tensor) - The second tensor to be computed. - - **input2** (Tensor) - The third tensor to be computed. - - **input3** (Tensor) - The fourth tensor to be computed. - - **input4** (Tensor) - The fifth tensor to be computed. - - **input5** (Tensor) - The sixth tensor to be computed. It will be updated by result. - - **greater_y** (Tensor) - The seventh tensor to be computed. - - **select_e** (Tensor) - The eighth tensor to be computed. - - **minimum_y** (Tensor) - The ninth tensor to be computed. - - Outputs: - A fake output tensor. - - Examples: - >>> import numpy as np - >>> import mindspore.context as context - >>> from mindspore.common import dtype as mstype - >>> from mindspore.common.tensor import Tensor - >>> from mindspore.common.parameter import Parameter - >>> from mindspore.nn.cell import Cell - >>> class Net(Cell): - ... def __init__(self, i5): - ... super(Net, self).__init__() - ... 
self.i5 = Parameter(i5, name='i5') - ... self.lamb_update = LambUpdateWithLR() - ... - ... def construct(self, i0, i1, i2, i3, i4, i6, i7, i8): - ... return self.lamb_update(i0, i1, i2, i3, i4, self.i5, i6, i7, i8) - >>> shape = [1, 16] - >>> oshape = [1] - >>> i0 = Tensor(np.random.normal(0, 1, oshape).astype(np.float32)) - >>> i1 = Tensor(np.random.normal(0, 1, oshape).astype(np.float32)) - >>> i2 = Tensor(np.random.normal(0, 1, oshape).astype(np.float32)) - >>> i3 = Tensor(np.random.normal(0, 1, oshape).astype(np.float32)) - >>> i4 = Tensor(np.random.normal(0, 1, shape).astype(np.float32)) - >>> i5 = Tensor(np.random.normal(0, 1, shape).astype(np.float32)) - >>> i6 = Tensor(np.random.normal(0, 1, oshape).astype(np.float32)) - >>> i7 = Tensor(np.random.normal(0, 1, oshape).astype(np.float32)) - >>> i8 = Tensor(np.random.normal(0, 1, oshape).astype(np.float32)) - >>> context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True) - >>> net = Net(i5) - >>> _ = net(i0, i1, i2, i3, i4, i6, i7, i8) - >>> output = (net.i5) - """ - - def __init__(self): - super(LambUpdateWithLR, self).__init__() - self.greater = P.Greater() - self.select = P.Select() - self.div = P.RealDiv() - self.min = P.Minimum() - self.max = P.Maximum() - self.mul = P.Mul() - self.sub = P.Sub() - self.fake_output_assign = InplaceAssign() - self.fake_output_assign.add_prim_attr("fake_output", True) - - def construct(self, input0, input1, input2, input3, input4, input5, greater_y, select_e, minimum_y): - greater0 = self.greater(input0, greater_y) - greater1 = self.greater(input1, greater_y) - real_div0 = self.div(input1, input2) - select0 = self.select(greater0, real_div0, select_e) - select1 = self.select(greater1, select0, select_e) - min0 = self.min(select1, minimum_y) - max0 = self.max(min0, greater_y) - mul0 = self.mul(max0, input3) - mul1 = self.mul(mul0, input4) - sub0 = self.sub(input5, mul1) - sub0 = self.fake_output_assign(input5, sub0, sub0) - return sub0 - - -class LambNextMV(GraphKernel): - r""" - Part of Lamb optimizer. - - .. math:: - rd_0 = \frac{i_8 \times i_5 + i_9 \times i_4}{i6} - rd_1 = \frac{x_0 \times i_2 + x_1 \times i_1}{i3} - y_2 = \frac{rd_0}{\sqrt{rd_1 + x3}} + x_2 \times i_7 - y_3 = \frac{rd_0}{\sqrt{rd_1} + x3} - i5 = i_8 \times i_5 + i_9 \times i_4 - i2 = x_0 \times i_2 + x_1 \times i_1 - - Inputs: - - **inputs1** (Tensor) - The first input tensor to be computed. - - **inputs2** (Tensor) - The second input tensor to be computed. It will be updated by result. - - **inputs3** (Tensor) - The third input tensor to be computed. - - **inputs4** (Tensor) - The fourth input tensor to be computed. - - **inputs5** (Tensor) - The fifth input tensor to be computed. It will be updated by result. - - **inputs6** (Tensor) - The sixth input tensor to be computed. - - **inputs7** (Tensor) - The seventh input tensor to be computed. - - **inputs8** (Tensor) - The eighth input tensor to be computed. - - **inputs9** (Tensor) - The ninth input tensor to be computed. - - **inputsx0** (Tensor) - The tenth input tensor to be computed. - - **inputsx1** (Tensor) - The eleventh input tensor to be computed. - - **inputsx2** (Tensor) - The twelfth input tensor to be computed. - - **inputsx3** (Tensor) - The thirteenth input tensor to be computed. - - Outputs: - Tuple of 2 Tensors. - - - **add3** (Tensor) - the shape is the same as the one after broadcasting, and the data type is - the one with higher precision or higher digits among the inputs. 
- - **realdiv4** (Tensor) - the shape is the same as the one after broadcasting, and the data type is - the one with higher precision or higher digits among the inputs. - - Examples: - >>> import numpy as np - >>> import mindspore.context as context - >>> from mindspore.common import dtype as mstype - >>> from mindspore.common.tensor import Tensor - >>> from mindspore.common.parameter import Parameter - >>> from mindspore.nn.cell import Cell - >>> class Net(Cell): - ... def __init__(self, i1, i4): - ... super(Net, self).__init__() - ... self.i1 = Parameter(i1, name='i1') - ... self.i4 = Parameter(i4, name='i4') - ... self.lamb_next = LambNextMV() - ... - ... def construct(self, i0, i2, i3, i5, i6, i7, i8, i9, i10, i11, i12): - ... i0_ = i0 + i2 - ... return self.lamb_next(i0_, self.i1, i2, i3, self.i4, i5, i6, i7, i8, i9, i10, i11, i12) - >>> shape = [1, 16] - >>> i0 = Tensor(np.abs(np.random.normal(0, 1, shape)).astype(np.float32)) - >>> i1 = Tensor(np.abs(np.random.normal(0, 1, shape)).astype(np.float32)) - >>> i2 = Tensor(np.abs(np.random.normal(0, 1, shape)).astype(np.float32)) - >>> i3 = Tensor(np.random.normal(0, 1, shape).astype(np.float32)) - >>> i4 = Tensor(np.random.normal(0, 1, shape).astype(np.float32)) - >>> i5 = Tensor(np.abs(np.random.normal(0, 1, shape)).astype(np.float32)) - >>> i6 = Tensor(np.random.normal(0, 1, shape).astype(np.float32)) - >>> i7 = Tensor(np.random.normal(0, 1, shape).astype(np.float32)) - >>> i8 = Tensor(np.random.normal(0, 1, shape).astype(np.float32)) - >>> i9 = Tensor(np.abs(np.random.normal(0, 1, shape)).astype(np.float32)) - >>> i10 = Tensor(np.abs(np.random.normal(0, 1, shape)).astype(np.float32)) - >>> i11 = Tensor(np.random.normal(0, 1, shape).astype(np.float32)) - >>> i12 = Tensor(np.ones(shape).astype(np.float32) * 1e-6) - >>> context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True) - >>> net = Net(i1, i4) - >>> (o0, o1) = net(i0, i2, i3, i5, i6, i7, i8, i9, i10, i11, i12) - >>> output = (o0, net.i4, net.i1, o1) - """ - - def __init__(self): - super(LambNextMV, self).__init__() - self.mul = P.Mul() - self.add = P.Add() - self.div = P.RealDiv() - self.sqrt = P.Sqrt() - self.rsqrt = P.Rsqrt() - self.fake_output_assign_1 = InplaceAssign() - self.fake_output_assign_1.add_prim_attr("fake_output", False) - self.fake_output_assign_2 = InplaceAssign() - self.fake_output_assign_2.add_prim_attr("fake_output", False) - - def construct(self, input1, input2, input3, input4, input5, input6, input7, - input8, input9, inputx0, inputx1, inputx2, inputx3): - mul3 = self.mul(inputx1, input1) - mul2 = self.mul(inputx0, input2) - add1 = self.add(mul2, mul3) - realdiv1 = self.div(add1, input3) - add2 = self.add(realdiv1, inputx3) - sqrt0 = self.rsqrt(add2) - sqrt1 = self.sqrt(realdiv1) - add4 = self.add(sqrt1, inputx3) - mul1 = self.mul(input9, input4) - mul0 = self.mul(input8, input5) - add0 = self.add(mul0, mul1) - realdiv0 = self.div(add0, input6) - realdiv2 = self.mul(realdiv0, sqrt0) - realdiv4 = self.div(realdiv0, add4) - mul4 = self.mul(inputx2, input7) - add3 = self.add(realdiv2, mul4) - - add3 = self.fake_output_assign_1(input5, add0, add3) - add3 = self.fake_output_assign_2(input2, add1, add3) - - return add3, realdiv4 diff --git a/mindspore/nn/layer/quant.py b/mindspore/nn/layer/quant.py index 62787a9a91..3703566ff6 100644 --- a/mindspore/nn/layer/quant.py +++ b/mindspore/nn/layer/quant.py @@ -375,8 +375,8 @@ class FakeQuantWithMinMaxObserver(UniformQuantObserver): def construct(self, x): if self.training: min_up, max_up = 
self.ema_update(x, self.minq, self.maxq) - P.Assign()(self.minq, min_up) - P.Assign()(self.maxq, max_up) + self.minq = min_up + self.maxq = max_up out = self.fake_quant_train(x, self.minq, self.maxq) else: out = self.fake_quant_infer(x, self.minq, self.maxq) @@ -765,14 +765,14 @@ class Conv2dBnFoldQuant(Cell): if self.training: out = self.batchnorm_fold2_train(out, self.beta, self.gamma, batch_std, batch_mean, running_std, running_mean, self.step) - F.control_depend(out, self.assignadd(self.step, self.one)) + self.assignadd(self.step, self.one) else: out = self.batchnorm_fold2_infer(out, self.beta, self.gamma, batch_std, batch_mean, running_std, running_mean, self.step) else: if self.training: out = self.batchnorm_fold2_train(out, self.beta, self.gamma, batch_std, batch_mean, running_std) - F.control_depend(out, self.assignadd(self.step, self.one)) + self.assignadd(self.step, self.one) else: out = self.batchnorm_fold2_infer(out, self.beta, self.gamma, running_std, running_mean, running_std) return out diff --git a/mindspore/nn/optim/adam.py b/mindspore/nn/optim/adam.py index 1a68287f9e..23515cbea2 100755 --- a/mindspore/nn/optim/adam.py +++ b/mindspore/nn/optim/adam.py @@ -113,8 +113,8 @@ def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov, op_sqrt = P.Sqrt() scatter_add = P.ScatterAdd(use_locking) - assign_m = F.assign(m, op_mul(beta1, m)) - assign_v = F.assign(v, op_mul(beta2, v)) + success = F.depend(success, F.assign(m, op_mul(beta1, m))) + success = F.depend(success, F.assign(v, op_mul(beta2, v))) grad_indices = gradient.indices grad_value = gradient.values @@ -129,27 +129,18 @@ def _run_opt_with_sparse(opt, sparse_opt, push, pull, use_locking, use_nesterov, if use_nesterov: m_temp = next_m * _scaler_ten - assign_m_nesterov = F.assign(m, op_mul(beta1, next_m)) + F.assign(m, op_mul(beta1, next_m)) div_value = scatter_add(m, op_mul(grad_indices, _scaler_one), op_mul(F.tuple_to_array((1.0,)) - beta1, grad_value)) param_update = div_value / (op_sqrt(next_v) + eps) - - m_recover = F.assign(m, m_temp / _scaler_ten) - - F.control_depend(m_temp, assign_m_nesterov) - F.control_depend(assign_m_nesterov, div_value) - F.control_depend(param_update, m_recover) + F.assign(m, m_temp / _scaler_ten) else: param_update = next_m / (op_sqrt(next_v) + eps) lr_t = lr * op_sqrt(1 - beta2_power) / (1 - beta1_power) - next_param = param - lr_t * param_update - F.control_depend(assign_m, next_m) - F.control_depend(assign_v, next_v) - success = F.depend(success, F.assign(param, next_param)) success = F.depend(success, F.assign(m, next_m)) success = F.depend(success, F.assign(v, next_v)) @@ -289,7 +280,7 @@ class Adam(Optimizer): ... {'order_params': net.trainable_params()}] >>> optim = nn.Adam(group_params, learning_rate=0.1, weight_decay=0.0) >>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01. - >>> # The no_conv_params's parameters will use learning rate of 0.01 and defaule weight decay of 0.0. + >>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0. >>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'. >>> >>> loss = nn.SoftmaxCrossEntropyWithLogits() @@ -564,7 +555,7 @@ class AdamOffload(Optimizer): ... {'order_params': net.trainable_params()}] >>> optim = nn.AdamOffload(group_params, learning_rate=0.1, weight_decay=0.0) >>> # The conv_params's parameters will use default learning rate of 0.1 and weight decay of 0.01. 
- >>> # The no_conv_params's parameters will use learning rate of 0.01 and defaule weight decay of 0.0. + >>> # The no_conv_params's parameters will use learning rate of 0.01 and default weight decay of 0.0. >>> # The final parameters order in which the optimizer will be followed is the value of 'order_params'. >>> >>> loss = nn.SoftmaxCrossEntropyWithLogits() diff --git a/mindspore/nn/optim/lamb.py b/mindspore/nn/optim/lamb.py index cffabcf393..0d54160c39 100755 --- a/mindspore/nn/optim/lamb.py +++ b/mindspore/nn/optim/lamb.py @@ -1,4 +1,4 @@ -# Copyright 2020 Huawei Technologies Co., Ltd +# Copyright 2020-2021 Huawei Technologies Co., Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -26,12 +26,13 @@ from mindspore._checkparam import Validator as validator from mindspore._checkparam import Rel from .optimizer import Optimizer from .. import layer -from .. import _graph_kernels as G + num_one = Tensor(np.ones([1]), mstype.float32) _lamb_opt = C.MultitypeFuncGraph("lamb_opt") + @_lamb_opt.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool") def _update_run_op(beta1, beta2, eps, global_step, lr, weight_decay, param, m, v, gradient, decay_flag, optim_filter): @@ -158,67 +159,6 @@ def _update_run_op_ascend(beta1, beta2, eps, global_step, lr, weight_decay, para return gradient -lamb_opt_graph_kernel = C.MultitypeFuncGraph("lamb_opt_graph_kernel") - - -@lamb_opt_graph_kernel.register("Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", - "Tensor", "Tensor", "Tensor", "Tensor", "Bool") -def _update_run_op_graph_kernel(beta1, beta2, eps, global_step, lr, weight_decay, param, m, v, gradient, decay_flag): - """ - Update parameters. - - Args: - beta1 (Tensor): The exponential decay rate for the 1st moment estimations. Should be in range (0.0, 1.0). - beta2 (Tensor): The exponential decay rate for the 2nd moment estimations. Should be in range (0.0, 1.0). - eps (Tensor): Term added to the denominator to improve numerical stability. Should be greater than 0. - lr (Tensor): Learning rate. - weight_decay (Number): Weight decay. Should be equal to or greater than 0. - global_step (Tensor): Global step. - param (Tensor): Parameters. - m (Tensor): m value of parameters. - v (Tensor): v value of parameters. - gradient (Tensor): Gradient of parameters. - decay_flag (bool): Specifies whether param update with weight decay. - - Returns: - Tensor, the new value of v after updating. 
- """ - op_mul = P.Mul() - op_square = P.Square() - op_cast = P.Cast() - op_shape = P.Shape() - op_pow = P.Pow() - op_norm = layer.Norm() - op_fill = P.Fill() - op_dtype = P.DType() - - param_fp32 = op_cast(param, mstype.float32) - gradient_fp32 = op_cast(gradient, mstype.float32) - - i6_ex = op_cast(global_step + num_one, mstype.float32) - i9 = op_cast(num_one, mstype.float32) - beta1 - x1 = op_cast(num_one, mstype.float32) - beta2 - i6 = op_cast(num_one, mstype.float32) - op_pow(beta1, i6_ex) - i3 = op_cast(num_one, mstype.float32) - op_pow(beta2, i6_ex) - i1 = op_square(gradient_fp32) - add3, update = G.LambNextMV()(i1, v, i3, gradient, m, i6, param, beta1, i9, beta2, x1, weight_decay, eps) - - if decay_flag: - update = update + op_mul(weight_decay, param_fp32) - - w_norm = op_norm(param_fp32) - g_norm = op_norm(gradient_fp32) - g_norm_hat = op_norm(add3) - - zeros = F.zeros_like(w_norm) - ones = op_fill(op_dtype(w_norm), op_shape(w_norm), 1.0) - tens = op_fill(op_dtype(w_norm), op_shape(w_norm), 10.0) - - next_param = G.LambUpdateWithLR()(g_norm, w_norm, g_norm_hat, lr, update, param, zeros, ones, tens) - next_v = F.control_depend(add3, next_param) - return next_v - - def _check_param_value(beta1, beta2, eps, prim_name): validator.check_value_type("beta1", beta1, [float], prim_name) validator.check_value_type("beta2", beta2, [float], prim_name) @@ -323,51 +263,32 @@ class Lamb(Optimizer): self.global_step = Parameter(initializer(0, [1]), name='global_step') self.assignadd = P.AssignAdd() self.hyper_map = C.HyperMap() - self.enable_graph_kernel = context.get_context("device_target") == "Ascend" and \ - context.get_context("enable_graph_kernel") self.device_ascend = context.get_context("device_target") == "Ascend" def construct(self, gradients): lr = self.get_lr() - if self.enable_graph_kernel: - if self.is_group: - if self.is_group_lr: - optim_result = self.hyper_map(F.partial(lamb_opt_graph_kernel, self.beta1, self.beta2, self.eps, - self.global_step), - lr, self.weight_decay, self.params, self.moments1, self.moments2, - gradients, self.decay_flags) - else: - optim_result = self.hyper_map(F.partial(lamb_opt_graph_kernel, self.beta1, self.beta2, self.eps, - self.global_step, lr), - self.weight_decay, self.params, self.moments1, self.moments2, - gradients, self.decay_flags) - else: - optim_result = self.hyper_map(F.partial(lamb_opt_graph_kernel, self.beta1, self.beta2, self.eps, - self.global_step, lr, self.weight_decay), - self.params, self.moments1, self.moments2, gradients, self.decay_flags) - else: - lamb_opt = _lamb_opt_ascend if self.device_ascend else _lamb_opt - if self.is_group: - if self.is_group_lr: - optim_result = self.hyper_map(F.partial(lamb_opt, self.beta1, self.beta2, self.eps, - self.global_step), - lr, self.weight_decay, self.params, self.moments1, self.moments2, - gradients, self.decay_flags, self.optim_filter) - else: - optim_result = self.hyper_map(F.partial(lamb_opt, self.beta1, self.beta2, self.eps, - self.global_step, lr), - self.weight_decay, self.params, self.moments1, self.moments2, - gradients, self.decay_flags, self.optim_filter) + lamb_opt = _lamb_opt_ascend if self.device_ascend else _lamb_opt + if self.is_group: + if self.is_group_lr: + optim_result = self.hyper_map(F.partial(lamb_opt, self.beta1, self.beta2, self.eps, + self.global_step), + lr, self.weight_decay, self.params, self.moments1, self.moments2, + gradients, self.decay_flags, self.optim_filter) else: optim_result = self.hyper_map(F.partial(lamb_opt, self.beta1, self.beta2, self.eps, - 
self.global_step, lr, self.weight_decay), - self.params, self.moments1, self.moments2, gradients, - self.decay_flags, self.optim_filter) + self.global_step, lr), + self.weight_decay, self.params, self.moments1, self.moments2, + gradients, self.decay_flags, self.optim_filter) + else: + optim_result = self.hyper_map(F.partial(lamb_opt, self.beta1, self.beta2, self.eps, + self.global_step, lr, self.weight_decay), + self.params, self.moments1, self.moments2, gradients, + self.decay_flags, self.optim_filter) if self.use_parallel: - self.broadcast_params(optim_result) + optim_result = F.depend(optim_result, self.broadcast_params(optim_result)) if not self.dynamic_lr: - F.control_depend(lr, self.assignadd(self.global_step, 1)) + optim_result = F.depend(optim_result, self.assignadd(self.global_step, 1)) return optim_result diff --git a/mindspore/nn/optim/momentum.py b/mindspore/nn/optim/momentum.py index cbbba57dec..ac21ec9f4d 100755 --- a/mindspore/nn/optim/momentum.py +++ b/mindspore/nn/optim/momentum.py @@ -27,15 +27,14 @@ _momentum_opt = C.MultitypeFuncGraph("momentum_opt") @_momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Bool", "Bool") def _tensor_run_opt_ext(opt, momentum, learning_rate, gradient, weight, moment, ps_parameter, cache_enable): """Apply momentum optimizer to the weight parameter using Tensor.""" - success = True if ps_parameter and not cache_enable: op_shape = P.Shape() _ps_pull = P.Pull() _ps_push = P.Push("ApplyMomentum", []) shapes = (op_shape(learning_rate), op_shape(gradient), op_shape(momentum)) - success = F.depend(success, _ps_pull(_ps_push((learning_rate, gradient, momentum), shapes), weight)) + success = F.depend(True, _ps_pull(_ps_push((learning_rate, gradient, momentum), shapes), weight)) else: - success = F.depend(success, opt(weight, moment, learning_rate, gradient, momentum)) + success = F.depend(True, opt(weight, moment, learning_rate, gradient, momentum)) return success diff --git a/mindspore/nn/optim/optimizer.py b/mindspore/nn/optim/optimizer.py index e3baaef6c8..2ea608487c 100755 --- a/mindspore/nn/optim/optimizer.py +++ b/mindspore/nn/optim/optimizer.py @@ -447,7 +447,7 @@ class Optimizer(Cell): else: lr = self.learning_rate(self.global_step) - F.control_depend(lr, self.assignadd(self.global_step, self.global_step_increase_tensor)) + self.assignadd(self.global_step, self.global_step_increase_tensor) return lr def get_lr_parameter(self, param): @@ -526,11 +526,7 @@ class Optimizer(Cell): new_param_group.append(next_params) for i in range(F.tuple_len(next_params)): F.assign(key_group[root][i], next_params[i]) - status = F.control_depend(optim_result, new_param_group[0][0]) - for i in range(self.dev_num - 1): - status = F.depend(F.control_depend(new_param_group[i], new_param_group[i+1][0]), status) - - return status + return new_param_group def construct(self, *hyper_params): raise NotImplementedError diff --git a/mindspore/nn/wrap/cell_wrapper.py b/mindspore/nn/wrap/cell_wrapper.py index 0cb9e44a4f..bd80e7369b 100644 --- a/mindspore/nn/wrap/cell_wrapper.py +++ b/mindspore/nn/wrap/cell_wrapper.py @@ -248,7 +248,8 @@ class TrainOneStepCell(Cell): sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) grads = self.grad(self.network, weights)(*inputs, sens) grads = self.grad_reducer(grads) - return F.depend(loss, self.optimizer(grads)) + loss = F.depend(loss, self.optimizer(grads)) + return loss class GetNextSingleOp(Cell): @@ -291,7 +292,7 @@ class _VirtualDatasetCell(Cell): """ Wrap the network with virtual dataset to 
convert data parallel layout to model parallel layout. - _VirtualDataset is a virtual Primitive, it does not exist in the final executing graph. Inputs and outpus + _VirtualDataset is a virtual Primitive, it does not exist in the final executing graph. Inputs and outputs of _VirtualDataset are distributed in data parallel pattern, tensor redistribution Primitives is inserted dynamically during the graph compile process. diff --git a/mindspore/nn/wrap/loss_scale.py b/mindspore/nn/wrap/loss_scale.py index 59f4652423..304996a4c4 100644 --- a/mindspore/nn/wrap/loss_scale.py +++ b/mindspore/nn/wrap/loss_scale.py @@ -22,8 +22,7 @@ from ...common.parameter import Parameter from ...ops import functional as F from ...ops import composite as C from ...ops import operations as P -from ...ops.operations import NPUGetFloatStatus, NPUAllocFloatStatus, NPUClearFloatStatus, ReduceSum, LessEqual, \ - ControlDepend +from ...ops.operations import NPUGetFloatStatus, NPUAllocFloatStatus, NPUClearFloatStatus, ReduceSum, LessEqual from ...common import dtype as mstype _grad_scale = C.MultitypeFuncGraph("grad_scale") @@ -139,16 +138,15 @@ class DynamicLossScaleUpdateCell(Cell): should_inc = self.less_equal(self.scale_window, self.cur_iter - self.last_overflow_iter) last_iter_cond = self.logic_or(overflow_cond, should_inc) last_overflow_iter = self.select(last_iter_cond, self.cur_iter, self.last_overflow_iter) - assign_last_iter = F.assign(self.last_overflow_iter, last_overflow_iter) + last_iter = F.assign(self.last_overflow_iter, last_overflow_iter) update_scale_cond = self.logic_and(should_inc, self.logic_not(overflow_cond)) scale_mul_res = loss_scale_on_overflow * self.scale_factor scaled_loss_scale = self.select(update_scale_cond, scale_mul_res, loss_scale_on_overflow) - assign_scaled_loss_scale = F.assign(loss_scale, scaled_loss_scale) + F.assign(loss_scale, scaled_loss_scale) inc_cur_iter = self.cur_iter + 1 - assing_cur_iter = F.assign(self.cur_iter, inc_cur_iter) - t = (assign_last_iter, assign_scaled_loss_scale, assing_cur_iter) - F.control_depend(assign_last_iter, assing_cur_iter) - return F.depend(overflow, t) + inc_cur_iter = F.depend(inc_cur_iter, last_iter) + F.assign(self.cur_iter, inc_cur_iter) + return overflow class FixedLossScaleUpdateCell(Cell): @@ -290,7 +288,6 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell): self.reduce_sum = ReduceSum(keep_dims=False) self.base = Tensor(1, mstype.float32) self.less_equal = LessEqual() - self.depend_parameter_use = ControlDepend(depend_mode=1) self.allreduce = P.AllReduce() self.is_distributed = self.parallel_mode != ParallelMode.STAND_ALONE @@ -307,7 +304,6 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell): else: raise TypeError("The scale_sense must be Cell or Tensor, but got {}".format(type(scale_sense))) - @C.add_flags(has_effect=True) def construct(self, *inputs): weights = self.weights loss = self.network(*inputs) @@ -315,8 +311,10 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell): if not self.gpu_target: # init overflow buffer init = self.alloc_status() - # clear overflow buffer - self.clear_status(init) + # clear overflow buffer after loss calculated + init = F.depend(init, loss) + clear_status = self.clear_status(init) + loss = F.depend(loss, clear_status) scaling_sens = self.scale_sense scaling_sens_filled = C.ones_like(loss) * F.cast(scaling_sens, F.dtype(loss)) @@ -326,7 +324,10 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell): grads = self.grad_reducer(grads) # get the overflow buffer if not self.gpu_target: - 
self.get_status(init) + # get overflow status after grads calculated + init = F.depend(init, grads) + get_status = self.get_status(init) + init = F.depend(init, get_status) # sum overflow buffer elements, 0:not overflow , >0:overflow flag_sum = self.reduce_sum(init, (0,)) else: @@ -344,12 +345,9 @@ class TrainOneStepWithLossScaleCell(TrainOneStepCell): if self.loss_scaling_manager is not None: overflow = self.loss_scaling_manager(self.scale_sense, cond) # if there is no overflow, do optimize - if overflow: - opt = False - else: - opt = self.optimizer(grads) - ret = (loss, cond, scaling_sens) - return F.depend(ret, opt) + if not overflow: + loss = F.depend(loss, self.optimizer(grads)) + return loss, cond, scaling_sens def set_sense_scale(self, sens): """If the user has set the sens in the training process and wants to reassign the value, he can call diff --git a/mindspore/ops/_grad/grad_implementations.py b/mindspore/ops/_grad/grad_implementations.py index 9c9a047054..cf9d88c25b 100644 --- a/mindspore/ops/_grad/grad_implementations.py +++ b/mindspore/ops/_grad/grad_implementations.py @@ -266,14 +266,28 @@ def bprop_control_depend(x, y, out, dout): """Backpropagator for primitive `Control_depend`.""" return C.zeros_like(x), C.zeros_like(y) + @bprops.register("switch") def bprop_switch(cond, tb, fb, out, dout): """Backpropagator for primitive `switch`.""" return C.zeros_like(cond), F.switch(cond, dout, C.zeros_like(tb)), \ - F.switch(cond, C.zeros_like(fb), dout) + F.switch(cond, C.zeros_like(fb), dout) + def _fprop_switch_layer(index, layers): """Backpropagator for primitive `switch_layer`.""" def _bprop_switch_layer(dout): return dout, C.zeros_like(index), () return F.switch_layer(index, layers), _bprop_switch_layer + + +@bprops.register("UpdateState") +def bprop_update_state(u_monad, x, out, dout): + """Backpropagator for primitive `UpdateState`.""" + return C.zeros_like(u_monad), C.zeros_like(x) + + +@bprops.register("Load") +def bprop_load(param, u_monad, out, dout): + """Backpropagator for primitive `load`.""" + return dout, C.zeros_like(u_monad) diff --git a/mindspore/ops/_grad/grad_other_ops.py b/mindspore/ops/_grad/grad_other_ops.py index 799b722ad8..2f9a6d3cbe 100644 --- a/mindspore/ops/_grad/grad_other_ops.py +++ b/mindspore/ops/_grad/grad_other_ops.py @@ -26,7 +26,7 @@ from .grad_base import bprop_getters def get_bprop_assign(self): """Generate bprop for Assign""" def bprop(x, y, out, dout): - return (x, zeros_like(y)) + return (dout, zeros_like(y)) return bprop diff --git a/mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py b/mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py index fad50318a9..a12be4da0c 100644 --- a/mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py +++ b/mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py @@ -37,9 +37,7 @@ fake_quant_perchannel_op_info = TBERegOp("FakeQuantPerChannel") \ .input(2, "max", None, "required", None) \ .output(0, "y", True, "required", "all") \ .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ - .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ - .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ .get_op_info() diff --git a/mindspore/ops/composite/multitype_ops/add_impl.py b/mindspore/ops/composite/multitype_ops/add_impl.py index 76b43120d0..41e29a869d 100644 --- 
a/mindspore/ops/composite/multitype_ops/add_impl.py +++ b/mindspore/ops/composite/multitype_ops/add_impl.py @@ -129,6 +129,21 @@ def _tensor_add_tensor(x, y): return F.tensor_add(x, y) +@add.register("RowTensor", "Tensor") +def add_rowtensor_tensor(x, y): + """ + Adds RowTensor and Tensor. + + Args: + x (RowTensor): The RowTensor addend. + y (Tensor): The Tensor addend. + + Returns: + RowTensor, the dtype is the same as x. + """ + return F.row_tensor_add(x, y) + + @_add_backward.register("EnvType", "EnvType") def _add_env(x, y): """ @@ -174,4 +189,47 @@ def _add_addn(x, y): return F.addn((x, y)) + +@_add_backward.register("UMonad", "UMonad") +def _add_umonad_umonad(x, y): + """ + Adds two monads. + + Args: + x (UMonad): The first monad. + y (UMonad): The second monad. + + Returns: + Monad, the same as x. + """ + return x + +@_add_backward.register("IOMonad", "IOMonad") +def _add_iomonad_iomonad(x, y): + """ + Adds two monads. + + Args: + x (IOMonad): The first monad. + y (IOMonad): The second monad. + + Returns: + Monad, the same as x. + """ + return x + +@_add_backward.register("RowTensor", "Tensor") +def _add_rowtensor_tensor(x, y): + """ + Adds RowTensor and Tensor. + + Args: + x (RowTensor): The RowTensor addend. + y (Tensor): The Tensor addend. + + Returns: + RowTensor, the dtype is the same as x. + """ + return x + y + + hyper_add = base.HyperMap(_add_backward) diff --git a/mindspore/ops/composite/multitype_ops/zeros_like_impl.py b/mindspore/ops/composite/multitype_ops/zeros_like_impl.py index e72cdf2ad3..f1659139ac 100644 --- a/mindspore/ops/composite/multitype_ops/zeros_like_impl.py +++ b/mindspore/ops/composite/multitype_ops/zeros_like_impl.py @@ -120,6 +120,34 @@ def _zeros_like_dict(x): return F.make_dict(keys, new_values) + +@zeros_like_leaf.register("UMonad") +def _zeros_like_u_monad(x): + """ + Return the UMonad unchanged. + + Args: + x (UMonad): The input monad. + + Returns: + x, the monad itself. + """ + return x + + +@zeros_like_leaf.register("IOMonad") +def _zeros_like_io_monad(x): + """ + Return the IOMonad unchanged. + + Args: + x (IOMonad): The input monad. + + Returns: + x, the monad itself. + """ + return x + + # zeros_like is an object that will generate graph of zero_like operation for different type zeros_like = base.HyperMap(zeros_like_leaf) """`zeros_like` is an object that will generate graph of `zero_like` operation for different type.""" diff --git a/mindspore/ops/functional.py b/mindspore/ops/functional.py index 60026e707f..21d6c906c7 100644 --- a/mindspore/ops/functional.py +++ b/mindspore/ops/functional.py @@ -177,6 +177,7 @@ make_row_tensor = Primitive('MakeRowTensor') row_tensor_get_values = Primitive('RowTensorGetValues') row_tensor_get_indices = Primitive('RowTensorGetIndices') row_tensor_get_dense_shape = Primitive('RowTensorGetDenseShape') +row_tensor_add = Primitive('RowTensorAdd') make_sparse_tensor = Primitive('MakeSparseTensor') sparse_tensor_get_values = Primitive('SparseTensorGetValues')
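Why identity and pass-through? During backprop, `hyper_add` folds gradient contributions from different branches, and `zeros_like` builds zero gradients for unused inputs; a monad slot must survive both untouched so the side-effect chain stays intact. A toy check of those two rules (plain Python, not MindSpore types):

```python
class UMonad:
    """Stand-in for a UMonad gradient slot; ordering only, no data."""


def add_backward(x, y):
    # Mirrors the registrations above: monads combine by identity,
    # ordinary gradients by addition.
    if isinstance(x, UMonad) and isinstance(y, UMonad):
        return x
    return x + y


def zeros_like_leaf(x):
    # zeros_like of a monad is the monad itself, not a zero tensor.
    return x if isinstance(x, UMonad) else 0 * x


u = UMonad()
assert add_backward(u, UMonad()) is u
assert add_backward(1.5, 2.5) == 4.0
assert zeros_like_leaf(u) is u
```

diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index 71adbf54b7..79363ad002 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -87,7 +87,7 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Adam from .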
import _quant_ops from ._quant_ops import * from .other_ops import (Assign, InplaceAssign, IOU, BoundingBoxDecode, BoundingBoxEncode, - ConfusionMatrix, PopulationCount, + ConfusionMatrix, PopulationCount, UpdateState, Load, CheckValid, Partial, Depend, identity, CheckBprop, Push, Pull) from ._thor_ops import (CusBatchMatMul, CusCholeskyTrsm, CusFusedAbsMax1, CusImg2Col, CusMatMulCubeDenseLeft, CusMatMulCubeFraczRightMul, CusMatMulCube, CusMatrixCombine, CusTranspose02314, @@ -300,6 +300,7 @@ __all__ = [ 'Partial', 'MakeRefKey', 'Depend', + 'UpdateState', 'identity', 'AvgPool', # Back Primitive diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index b56df82935..d557d457f9 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -59,6 +59,7 @@ class _ScatterOp(PrimitiveWithInfer): """Initialize _ScatterOp""" validator.check_value_type('use_locking', use_locking, [bool], self.name) self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y']) + self.add_prim_attr('side_effect_mem', True) def infer_shape(self, x_shape, indices_shape, updates_shape): self._check_scatter_shape(x_shape, indices_shape, updates_shape, self.name) @@ -98,6 +99,7 @@ class _ScatterOp_Dynamic(PrimitiveWithCheck): """Initialize _ScatterOp_Dynamic""" validator.check_value_type('use_locking', use_locking, [bool], self.name) self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y']) + self.add_prim_attr('side_effect_mem', True) def check_shape(self, x_shape, indices_shape, updates_shape): self._check_scatter_shape(x_shape, indices_shape, updates_shape, self.name) @@ -3320,6 +3322,8 @@ class ScatterUpdate(_ScatterOp_Dynamic): """Initialize ScatterUpdate""" validator.check_value_type('use_locking', use_locking, [bool], self.name) self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y']) + self.add_prim_attr('side_effect_mem', True) + class ScatterNdUpdate(_ScatterNdOp): @@ -3375,6 +3379,7 @@ class ScatterNdUpdate(_ScatterNdOp): """Initialize ScatterNdUpdate""" validator.check_value_type('use_locking', use_locking, [bool], self.name) self.init_prim_io_names(inputs=['x', 'indices', 'value'], outputs=['y']) + self.add_prim_attr('side_effect_mem', True) def infer_dtype(self, x_dtype, indices_dtype, value_dtype): validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name) @@ -3427,12 +3432,6 @@ class ScatterMax(_ScatterOp): [88. 88. 
88.]] """ - @prim_attr_register - def __init__(self, use_locking=True): - """Initialize ScatterMax""" - self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y']) - validator.check_value_type('use_locking', use_locking, (bool,), self.name) - class ScatterMin(_ScatterOp): r""" @@ -3528,6 +3527,7 @@ class ScatterAdd(_ScatterOp_Dynamic): """Initialize ScatterAdd""" validator.check_value_type('use_locking', use_locking, [bool], self.name) self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y']) + self.add_prim_attr('side_effect_mem', True) class ScatterSub(_ScatterOp): @@ -3800,6 +3800,7 @@ class ScatterNonAliasingAdd(_ScatterNdOp): def __init__(self): """Initialize ScatterNonAliasingAdd""" self.init_prim_io_names(inputs=['x', 'indices', 'updates'], outputs=['y']) + self.add_prim_attr('side_effect_mem', True) def infer_dtype(self, x_dtype, indices_dtype, updates_dtype): validator.check_tensor_dtype_valid('indices', indices_dtype, [mstype.int32], self.name) @@ -4919,9 +4920,13 @@ class Identity(PrimitiveWithInfer): [1 2 3 4] """ + # Side effect is identity with input. + side_effect_propagate = 1 + @prim_attr_register def __init__(self): """Initialize identity""" + self.add_prim_attr('side_effect_propagate', 1) def __infer__(self, x): validator.check_subclass("x", x['dtype'], mstype.tensor, self.name) diff --git a/mindspore/ops/operations/debug_ops.py b/mindspore/ops/operations/debug_ops.py index 3d607c6013..3d65a2ada4 100644 --- a/mindspore/ops/operations/debug_ops.py +++ b/mindspore/ops/operations/debug_ops.py @@ -79,6 +79,7 @@ class ScalarSummary(PrimitiveWithInfer): @prim_attr_register def __init__(self): """init""" + self.add_prim_attr("side_effect_io", True) def __infer__(self, name, value): _check_summary_param(name, value, self.__class__.__name__) @@ -119,6 +120,7 @@ class ImageSummary(PrimitiveWithInfer): @prim_attr_register def __init__(self): """init""" + self.add_prim_attr("side_effect_io", True) def __infer__(self, name, value): _check_summary_param(name, value, self.__class__.__name__) @@ -162,6 +164,7 @@ class TensorSummary(PrimitiveWithInfer): @prim_attr_register def __init__(self): """init""" + self.add_prim_attr("side_effect_io", True) def __infer__(self, name, value): _check_summary_param(name, value, self.__class__.__name__) @@ -204,6 +207,7 @@ class HistogramSummary(PrimitiveWithInfer): @prim_attr_register def __init__(self): """init""" + self.add_prim_attr("side_effect_io", True) def __infer__(self, name, value): _check_summary_param(name, value, self.__class__.__name__) @@ -267,6 +271,7 @@ class InsertGradientOf(PrimitiveWithInfer): @prim_attr_register def __init__(self, f): + self.add_prim_attr('side_effect_backprop', True) self.f = f def infer_shape(self, x_shape): @@ -373,7 +378,7 @@ class Print(PrimitiveWithInfer): @prim_attr_register def __init__(self): - self.add_prim_attr("_side_effect", True) + self.add_prim_attr("side_effect_io", True) def __call__(self, *args): for arg in args: @@ -383,7 +388,8 @@ class Print(PrimitiveWithInfer): return [1] def infer_dtype(self, *inputs): - for ele in inputs: + # check argument types except the last one (io state). 
+ for ele in inputs[:-1]: if isinstance(ele, (tuple, list)): self.infer_dtype(*ele) else: diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py index 1f03c2c909..339c1830c6 100644 --- a/mindspore/ops/operations/math_ops.py +++ b/mindspore/ops/operations/math_ops.py @@ -229,6 +229,7 @@ class AssignAdd(PrimitiveWithInfer): def __init__(self): """Initialize AssignAdd""" self.init_prim_io_names(inputs=['ref', 'value'], outputs=['output']) + self.add_prim_attr('side_effect_mem', True) def infer_shape(self, variable, value): return value @@ -284,6 +285,8 @@ class AssignSub(PrimitiveWithInfer): @prim_attr_register def __init__(self): """Initialize AssignSub""" + self.init_prim_io_names(inputs=['ref', 'value'], outputs=['output']) + self.add_prim_attr('side_effect_mem', True) def infer_shape(self, variable, value): return value diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index d50552835d..ef9c002a7a 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -18,7 +18,6 @@ import math import operator from functools import reduce, partial -from mindspore import log as logger from mindspore._checkparam import _check_3d_int_or_tuple from mindspore import log as logger import numpy as np @@ -855,7 +854,7 @@ class FusedBatchNorm(Primitive): (128, 64, 32, 64) """ __mindspore_signature__ = ( - sig.make_sig('input_x', dtype=sig.sig_dtype.T2), + sig.make_sig('input_x', dtype=sig.sig_dtype.T1), sig.make_sig('scale', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), sig.make_sig('bias', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), sig.make_sig('mean', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), @@ -952,7 +951,7 @@ class FusedBatchNormEx(PrimitiveWithInfer): (128, 64, 32, 64) """ __mindspore_signature__ = ( - sig.make_sig('input_x', dtype=sig.sig_dtype.T2), + sig.make_sig('input_x', dtype=sig.sig_dtype.T1), sig.make_sig('scale', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), sig.make_sig('bias', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), sig.make_sig('mean', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), @@ -1955,6 +1954,11 @@ class Conv2DBackpropInput(PrimitiveWithInfer): >>> print(output.shape) (10, 32, 32, 32) """ + __mindspore_signature__ = ( + sig.make_sig('out_backprop', dtype=sig.sig_dtype.T), + sig.make_sig('filter', dtype=sig.sig_dtype.T1), + sig.make_sig('input_sizes', dtype=sig.sig_dtype.T2) + ) @prim_attr_register def __init__(self, @@ -2408,7 +2412,7 @@ class ApplyMomentum(PrimitiveWithInfer): sig.make_sig('accumulation', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), sig.make_sig('learning_rate', dtype=sig.sig_dtype.T1), sig.make_sig('gradient', dtype=sig.sig_dtype.T), - sig.make_sig('momentum', dtype=sig.sig_dtype.T2), + sig.make_sig('momentum', dtype=sig.sig_dtype.T2) ) @prim_attr_register @@ -2418,6 +2422,7 @@ class ApplyMomentum(PrimitiveWithInfer): validator.check_value_type('gradient_scale', gradient_scale, [float], self.name) self.init_prim_io_names(inputs=['variable', 'accumulation', 'learning_rate', 'gradient', 'momentum'], outputs=['output']) + self.add_prim_attr('side_effect_mem', True) def infer_shape(self, v_shape, a_shape, l_shape, g_shape, m_shape): return v_shape @@ -2687,6 +2692,7 @@ class SGD(PrimitiveWithCheck): raise ValueError(f"Nesterov need zero dampening!") self.init_prim_io_names(inputs=['parameters', 'gradient', 'learning_rate', 'accum', 'momentum', 'stat'], outputs=['output']) + self.add_prim_attr('side_effect_mem', True) def check_shape(self, parameters_shape, 
gradient_shape, learning_rate_shape, accum_shape, momentum_shape, stat_shape): @@ -2771,6 +2777,7 @@ class ApplyRMSProp(PrimitiveWithInfer): self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) self.init_prim_io_names(inputs=['var', 'mean_square', 'moment', 'learning_rate', 'grad', 'rho', 'momentum', 'epsilon'], outputs=['output']) + self.add_prim_attr('side_effect_mem', True) def infer_shape(self, var_shape, mean_square_shape, moment_shape, learning_rate_shape, grad_shape, decay_shape, momentum_shape, epsilon_shape): @@ -2868,6 +2875,7 @@ class ApplyCenteredRMSProp(PrimitiveWithInfer): @prim_attr_register def __init__(self, use_locking=False): self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) + self.add_prim_attr('side_effect_mem', True) def infer_shape(self, var_shape, mean_gradient_shape, mean_square_shape, moment_shape, grad_shape, learning_rate_shape, decay_shape, momentum_shape, epsilon_shape): @@ -3045,6 +3053,7 @@ class DropoutGenMask(Primitive): validator.check_value_type("Seed0", Seed0, [int], self.name) validator.check_value_type("Seed1", Seed1, [int], self.name) self.add_prim_attr("_random_effect", True) + self.add_prim_attr('side_effect_mem', True) class DropoutDoMask(PrimitiveWithInfer): @@ -4065,6 +4074,7 @@ class Adam(PrimitiveWithInfer): def __init__(self, use_locking=False, use_nesterov=False): validator.check_value_type("use_locking", use_locking, [bool], self.name) validator.check_value_type("use_nesterov", use_nesterov, [bool], self.name) + self.add_prim_attr('side_effect_mem', True) def infer_shape(self, var_shape, m_shape, v_shape, beta1_power_shape, beta2_power_shape, lr_shape, beta1_shape, beta2_shape, epsilon_shape, grad_shape): @@ -4291,7 +4301,7 @@ class FusedSparseAdam(PrimitiveWithInfer): sig.make_sig('beta2', dtype=sig.sig_dtype.T), sig.make_sig('epsilon', dtype=sig.sig_dtype.T), sig.make_sig('grad', dtype=sig.sig_dtype.T), - sig.make_sig('indices', dtype=sig.sig_dtype.T1), + sig.make_sig('indices', dtype=sig.sig_dtype.T1) ) @prim_attr_register @@ -4301,6 +4311,7 @@ class FusedSparseAdam(PrimitiveWithInfer): self.init_prim_io_names(inputs=['var', 'm', 'v', 'beta1_power', 'beta2_power', 'lr', 'beta1', 'beta2', 'epsilon', 'grad', 'indices'], outputs=['var', 'm', 'v']) + self.add_prim_attr('side_effect_mem', True) def infer_shape(self, var_shape, m_shape, v_shape, beta1_power_shape, beta2_power_shape, lr_shape, beta1_shape, beta2_shape, epsilon_shape, grad_shape, indices_shape): @@ -4429,7 +4440,7 @@ class FusedSparseLazyAdam(PrimitiveWithInfer): sig.make_sig('beta2', dtype=sig.sig_dtype.T), sig.make_sig('epsilon', dtype=sig.sig_dtype.T), sig.make_sig('grad', dtype=sig.sig_dtype.T), - sig.make_sig('indices', dtype=sig.sig_dtype.T1), + sig.make_sig('indices', dtype=sig.sig_dtype.T1) ) @prim_attr_register @@ -4439,6 +4450,7 @@ class FusedSparseLazyAdam(PrimitiveWithInfer): self.init_prim_io_names(inputs=['var', 'm', 'v', 'beta1_power', 'beta2_power', 'lr', 'beta1', 'beta2', 'epsilon', 'grad', 'indices'], outputs=['var', 'm', 'v']) + self.add_prim_attr('side_effect_mem', True) def infer_shape(self, var_shape, m_shape, v_shape, beta1_power_shape, beta2_power_shape, lr_shape, beta1_shape, beta2_shape, epsilon_shape, grad_shape, indices_shape): @@ -4533,13 +4545,15 @@ class FusedSparseFtrl(PrimitiveWithInfer): sig.make_sig('accum', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), sig.make_sig('linear', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), sig.make_sig('grad', dtype=sig.sig_dtype.T), 
- sig.make_sig('indices', dtype=sig.sig_dtype.T1), + sig.make_sig('indices', dtype=sig.sig_dtype.T1) ) @prim_attr_register def __init__(self, lr, l1, l2, lr_power, use_locking=False): self.init_prim_io_names(inputs=['var', 'accum', 'linear', 'grad', 'indices'], outputs=['output']) + self.add_prim_attr('side_effect_mem', True) + validator.check_value_type("lr", lr, [float], self.name) validator.check_value_type("l1", l1, [float], self.name) validator.check_value_type("l2", l2, [float], self.name) @@ -4642,20 +4656,23 @@ class FusedSparseProximalAdagrad(PrimitiveWithInfer): sig.make_sig('l1', dtype=sig.sig_dtype.T), sig.make_sig('l2', dtype=sig.sig_dtype.T), sig.make_sig('grad', dtype=sig.sig_dtype.T), - sig.make_sig('indices', dtype=sig.sig_dtype.T1), + sig.make_sig('indices', dtype=sig.sig_dtype.T1) ) @prim_attr_register def __init__(self, use_locking=False): self.init_prim_io_names(inputs=['var', 'accum', 'lr', 'l1', 'l2', 'grad', 'indices'], outputs=['output']) + self.add_prim_attr('side_effect_mem', True) self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) - def infer_shape(self, var_shape, accum_shape, lr_shape, l1_shape, l2_shape, grad_shape, indices_shape): + def infer_shape(self, var_shape, accum_shape, lr_shape, l1_shape, l2_shape, + grad_shape, indices_shape): validator.check_int(len(indices_shape), 1, Rel.EQ, "indices rank", self.name) return [1], [1] - def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype, indices_dtype): + def infer_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, + grad_dtype, indices_dtype): args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype} validator.check_tensors_dtypes_same_and_valid(args, [mstype.float32], self.name) validator.check_scalar_or_tensor_types_same({"lr": lr_dtype}, [mstype.float32], self.name) @@ -4932,12 +4949,13 @@ class ApplyAdaMax(PrimitiveWithInfer): sig.make_sig('beta1', dtype=sig.sig_dtype.T3), sig.make_sig('beta2', dtype=sig.sig_dtype.T4), sig.make_sig('epsilon', dtype=sig.sig_dtype.T5), - sig.make_sig('grad', dtype=sig.sig_dtype.T), + sig.make_sig('grad', dtype=sig.sig_dtype.T) ) @prim_attr_register def __init__(self): """Initialize ApplyAdaMax""" + self.add_prim_attr('side_effect_mem', True) def infer_shape(self, var_shape, m_shape, v_shape, beta1_power_shape, lr_shape, beta1_shape, beta2_shape, epsilon_shape, grad_shape): @@ -5061,12 +5079,13 @@ class ApplyAdadelta(PrimitiveWithInfer): sig.make_sig('lr', dtype=sig.sig_dtype.T1), sig.make_sig('rho', dtype=sig.sig_dtype.T2), sig.make_sig('epsilon', dtype=sig.sig_dtype.T3), - sig.make_sig('grad', dtype=sig.sig_dtype.T), + sig.make_sig('grad', dtype=sig.sig_dtype.T) ) @prim_attr_register def __init__(self): """Initialize ApplyAdadelta""" + self.add_prim_attr('side_effect_mem', True) def infer_shape(self, var_shape, accum_shape, accum_update_shape, lr_shape, rho_shape, epsilon_shape, grad_shape): @@ -5166,12 +5185,13 @@ class ApplyAdagrad(PrimitiveWithInfer): sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), sig.make_sig('accum', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), sig.make_sig('lr', dtype=sig.sig_dtype.T1), - sig.make_sig('grad', dtype=sig.sig_dtype.T), + sig.make_sig('grad', dtype=sig.sig_dtype.T) ) @prim_attr_register def __init__(self, update_slots=True): validator.check_value_type("update_slots", update_slots, [bool], self.name) + self.add_prim_attr('side_effect_mem', True) def infer_shape(self, var_shape, accum_shape, lr_shape, grad_shape): 
validator.check('accum shape', accum_shape, 'var shape', var_shape, Rel.EQ, self.name) @@ -5260,13 +5280,14 @@ class ApplyAdagradV2(PrimitiveWithInfer): sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), sig.make_sig('accum', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), sig.make_sig('lr', dtype=sig.sig_dtype.T1), - sig.make_sig('grad', dtype=sig.sig_dtype.T), + sig.make_sig('grad', dtype=sig.sig_dtype.T) ) @prim_attr_register def __init__(self, epsilon, update_slots=True): validator.check_value_type("epsilon", epsilon, [float], self.name) validator.check_value_type("update_slots", update_slots, [bool], self.name) + self.add_prim_attr('side_effect_mem', True) def infer_shape(self, var_shape, accum_shape, lr_shape, grad_shape): validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name) @@ -5353,7 +5374,7 @@ class SparseApplyAdagrad(PrimitiveWithInfer): sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), sig.make_sig('accum', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), sig.make_sig('grad', dtype=sig.sig_dtype.T), - sig.make_sig('indices', dtype=sig.sig_dtype.T1), + sig.make_sig('indices', dtype=sig.sig_dtype.T1) ) @prim_attr_register @@ -5362,6 +5383,7 @@ class SparseApplyAdagrad(PrimitiveWithInfer): validator.check_is_float(lr, "lr", self.name) validator.check_value_type("update_slots", update_slots, [bool], self.name) validator.check_value_type("use_locking", use_locking, [bool], self.name) + self.add_prim_attr('side_effect_mem', True) def infer_shape(self, var_shape, accum_shape, grad_shape, indices_shape): validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name) @@ -5450,7 +5472,7 @@ class SparseApplyAdagradV2(PrimitiveWithInfer): sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), sig.make_sig('accum', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), sig.make_sig('grad', dtype=sig.sig_dtype.T), - sig.make_sig('indices', dtype=sig.sig_dtype.T1), + sig.make_sig('indices', dtype=sig.sig_dtype.T1) ) @prim_attr_register @@ -5459,6 +5481,7 @@ class SparseApplyAdagradV2(PrimitiveWithInfer): self.epsilon = validator.check_value_type("epsilon", epsilon, [float], self.name) self.use_locking = validator.check_value_type("update_slots", update_slots, [bool], self.name) self.update_slots = validator.check_value_type("use_locking", use_locking, [bool], self.name) + self.add_prim_attr('side_effect_mem', True) def infer_shape(self, var_shape, accum_shape, grad_shape, indices_shape): validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name) @@ -5553,13 +5576,14 @@ class ApplyProximalAdagrad(PrimitiveWithInfer): sig.make_sig('lr', dtype=sig.sig_dtype.T1), sig.make_sig('l1', dtype=sig.sig_dtype.T2), sig.make_sig('l2', dtype=sig.sig_dtype.T3), - sig.make_sig('grad', dtype=sig.sig_dtype.T), + sig.make_sig('grad', dtype=sig.sig_dtype.T) ) @prim_attr_register def __init__(self, use_locking=False): self.init_prim_io_names(inputs=['var', 'accum', 'lr', 'l1', 'l2', 'grad'], outputs=['var', 'accum']) + self.add_prim_attr('side_effect_mem', True) self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) def infer_shape(self, var_shape, accum_shape, lr_shape, l1_shape, l2_shape, grad_shape): @@ -5672,19 +5696,22 @@ class SparseApplyProximalAdagrad(PrimitiveWithCheck): sig.make_sig('l1', dtype=sig.sig_dtype.T2), sig.make_sig('l2', dtype=sig.sig_dtype.T3), sig.make_sig('grad', dtype=sig.sig_dtype.T), - sig.make_sig('indices', dtype=sig.sig_dtype.T4), + 
sig.make_sig('indices', dtype=sig.sig_dtype.T4) ) @prim_attr_register def __init__(self, use_locking=False): self.init_prim_io_names(inputs=['var', 'accum', 'lr', 'l1', 'l2', 'grad', 'indices'], outputs=['var', 'accum']) + self.add_prim_attr('side_effect_mem', True) self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) - def check_shape(self, var_shape, accum_shape, lr_shape, l1_shape, l2_shape, grad_shape, indices_shape): + def check_shape(self, var_shape, accum_shape, lr_shape, l1_shape, l2_shape, + grad_shape, indices_shape): validator.check_int(len(indices_shape), 1, Rel.EQ, "indices rank", self.name) - def check_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, grad_dtype, indices_dtype): + def check_dtype(self, var_dtype, accum_dtype, lr_dtype, l1_dtype, l2_dtype, + grad_dtype, indices_dtype): args = {'var': var_dtype, 'accum': accum_dtype, 'grad': grad_dtype} validator.check_tensors_dtypes_same_and_valid(args, [mstype.float16, mstype.float32], self.name) validator.check_scalar_or_tensor_types_same({"lr": lr_dtype}, [mstype.float16, mstype.float32], self.name) @@ -5772,14 +5799,16 @@ class ApplyAddSign(PrimitiveWithInfer): sig.make_sig('alpha', dtype=sig.sig_dtype.T2), sig.make_sig('sign_decay', dtype=sig.sig_dtype.T3), sig.make_sig('beta', dtype=sig.sig_dtype.T3), - sig.make_sig('grad', dtype=sig.sig_dtype.T), + sig.make_sig('grad', dtype=sig.sig_dtype.T) ) @prim_attr_register def __init__(self): "Initialize ApplyAddSign" + self.add_prim_attr('side_effect_mem', True) - def infer_shape(self, var_shape, m_shape, lr_shape, alpha_shape, sign_decay_shape, beta_shape, grad_shape): + def infer_shape(self, var_shape, m_shape, lr_shape, alpha_shape, sign_decay_shape, + beta_shape, grad_shape): validator.check('m_shape', m_shape, 'var_shape', var_shape, Rel.EQ, self.name) validator.check('grad_shape', grad_shape, 'var_shape', var_shape, Rel.EQ, self.name) lr_shape_len = len(lr_shape) @@ -5800,7 +5829,8 @@ class ApplyAddSign(PrimitiveWithInfer): validator.check_int(beta_shape[0], 1, Rel.EQ, "beta_shape[0]", self.name) return var_shape, m_shape - def infer_dtype(self, var_dtype, m_dtype, lr_dtype, alpha_dtype, sign_decay_dtype, beta_dtype, grad_dtype): + def infer_dtype(self, var_dtype, m_dtype, lr_dtype, alpha_dtype, sign_decay_dtype, + beta_dtype, grad_dtype): valid_dtypes = [mstype.float16, mstype.float32] args = {'var': var_dtype, 'm': m_dtype, 'grad': grad_dtype} validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name) @@ -5892,14 +5922,16 @@ class ApplyPowerSign(PrimitiveWithInfer): sig.make_sig('logbase', dtype=sig.sig_dtype.T), sig.make_sig('sign_decay', dtype=sig.sig_dtype.T), sig.make_sig('beta', dtype=sig.sig_dtype.T), - sig.make_sig('grad', dtype=sig.sig_dtype.T), + sig.make_sig('grad', dtype=sig.sig_dtype.T) ) @prim_attr_register def __init__(self): "Initialize ApplyPowerSign" + self.add_prim_attr('side_effect_mem', True) - def infer_shape(self, var_shape, m_shape, lr_shape, logbase_shape, sign_decay_shape, beta_shape, grad_shape): + def infer_shape(self, var_shape, m_shape, lr_shape, logbase_shape, sign_decay_shape, + beta_shape, grad_shape): validator.check('m_shape', m_shape, 'var_shape', var_shape, Rel.EQ, self.name) validator.check('grad_shape', grad_shape, 'var_shape', var_shape, Rel.EQ, self.name) lr_shape_len = len(lr_shape) @@ -5920,7 +5952,8 @@ class ApplyPowerSign(PrimitiveWithInfer): validator.check_int(beta_shape[0], 1, Rel.EQ, "beta_shape[0]", self.name) return var_shape, m_shape - def 
infer_dtype(self, var_dtype, m_dtype, lr_dtype, logbase_dtype, sign_decay_dtype, beta_dtype, grad_dtype): + def infer_dtype(self, var_dtype, m_dtype, lr_dtype, logbase_dtype, sign_decay_dtype, + beta_dtype, grad_dtype): valid_dtypes = [mstype.float16, mstype.float32] args = {'var': var_dtype, 'm': m_dtype, 'grad': grad_dtype} validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name) @@ -5979,12 +6012,13 @@ class ApplyGradientDescent(PrimitiveWithInfer): __mindspore_signature__ = ( sig.make_sig('var', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), sig.make_sig('alpha', dtype=sig.sig_dtype.T1), - sig.make_sig('delta', dtype=sig.sig_dtype.T), + sig.make_sig('delta', dtype=sig.sig_dtype.T) ) @prim_attr_register def __init__(self): "Initialize ApplyGradientDescent" + self.add_prim_attr('side_effect_mem', True) def infer_shape(self, var_shape, alpha_shape, delta_shape): validator.check('delta shape', delta_shape, 'var shape', var_shape, Rel.EQ, self.name) @@ -6060,12 +6094,13 @@ class ApplyProximalGradientDescent(PrimitiveWithInfer): sig.make_sig('alpha', dtype=sig.sig_dtype.T1), sig.make_sig('l1', dtype=sig.sig_dtype.T2), sig.make_sig('l2', dtype=sig.sig_dtype.T3), - sig.make_sig('delta', dtype=sig.sig_dtype.T), + sig.make_sig('delta', dtype=sig.sig_dtype.T) ) @prim_attr_register def __init__(self): "Initialize ApplyGradientDescent" + self.add_prim_attr('side_effect_mem', True) def infer_shape(self, var_shape, alpha_shape, l1_shape, l2_shape, delta_shape): validator.check('delta shape', delta_shape, 'var shape', var_shape, Rel.EQ, self.name) @@ -6246,6 +6281,7 @@ class ApplyFtrl(PrimitiveWithInfer): def __init__(self, use_locking=False): self.init_prim_io_names(inputs=['var', 'accum', 'linear', 'grad', 'lr', 'l1', 'l2', 'lr_power'], outputs=['output']) + self.add_prim_attr('side_effect_mem', True) self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) def infer_shape(self, var_shape, accum_shape, linear_shape, grad_shape, lr_shape, l1_shape, l2_shape, @@ -6254,7 +6290,8 @@ class ApplyFtrl(PrimitiveWithInfer): validator.check('var shape', var_shape, 'linear shape', linear_shape, Rel.EQ, self.name) return var_shape - def infer_dtype(self, var_type, accum_type, linear_type, grad_type, lr_type, l1_type, l2_type, lr_power_type): + def infer_dtype(self, var_type, accum_type, linear_type, grad_type, lr_type, l1_type, l2_type, + lr_power_type): valid_dtypes = [mstype.float16, mstype.float32] args = {'var': var_type, 'accum': accum_type, 'linear': linear_type, 'grad': grad_type} validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name) @@ -6336,7 +6373,7 @@ class SparseApplyFtrl(PrimitiveWithCheck): sig.make_sig('accum', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), sig.make_sig('linear', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), sig.make_sig('grad', dtype=sig.sig_dtype.T), - sig.make_sig('indices', dtype=sig.sig_dtype.T1), + sig.make_sig('indices', dtype=sig.sig_dtype.T1) ) @prim_attr_register @@ -6352,6 +6389,7 @@ class SparseApplyFtrl(PrimitiveWithCheck): self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) self.init_prim_io_names(inputs=['var', 'accum', 'linear', 'grad', 'indices'], outputs=['var', 'accum', 'linear']) + self.add_prim_attr('side_effect_mem', True) def check_shape(self, var_shape, accum_shape, linear_shape, grad_shape, indices_shape): validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name) @@ -6443,7 +6481,7 @@ class 
SparseApplyFtrlV2(PrimitiveWithInfer): sig.make_sig('accum', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), sig.make_sig('linear', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), sig.make_sig('grad', dtype=sig.sig_dtype.T), - sig.make_sig('indices', dtype=sig.sig_dtype.T1), + sig.make_sig('indices', dtype=sig.sig_dtype.T1) ) @prim_attr_register @@ -6458,6 +6496,7 @@ class SparseApplyFtrlV2(PrimitiveWithInfer): self.lr_power = validator.check_number("lr_power", lr_power, 0, Rel.LE, self.name) self.l2_shrinkage = validator.check_value_type("l2_shrinkage", l2_shrinkage, [float], self.name) self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) + self.add_prim_attr('side_effect_mem', True) def infer_shape(self, var_shape, accum_shape, linear_shape, grad_shape, indices_shape): validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name) diff --git a/mindspore/ops/operations/other_ops.py b/mindspore/ops/operations/other_ops.py index 0e70e8823f..27255aa2dc 100644 --- a/mindspore/ops/operations/other_ops.py +++ b/mindspore/ops/operations/other_ops.py @@ -15,6 +15,7 @@ """Other operators.""" import functools +from mindspore.common import monad from .. import signature as sig from ..._checkparam import Validator as validator, Rel from ...common import dtype as mstype @@ -58,12 +59,14 @@ class Assign(PrimitiveWithCheck): """ __mindspore_signature__ = ( sig.make_sig('variable', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T), - sig.make_sig('value', dtype=sig.sig_dtype.T) + sig.make_sig('value', dtype=sig.sig_dtype.T), + sig.make_sig('u', default=monad.U, dtype=sig.sig_dtype.T1) ) @prim_attr_register def __init__(self): self.init_prim_io_names(inputs=['ref', 'value'], outputs=['output']) + self.add_prim_attr('side_effect_mem', True) def check_dtype(self, variable, value): types = mstype.number_type + (mstype.bool_,) @@ -108,6 +111,28 @@ class InplaceAssign(PrimitiveWithInfer): def infer_dtype(self, x, y, z): return z +class Load(PrimitiveWithCheck): + """ + Load the value of a `Parameter`. + + Inputs: + - **variable** (Parameter) - The `Parameter`. + + Outputs: + Tensor - The loaded parameter tensor value. + """ + __mindspore_signature__ = ( + sig.make_sig('variable', sig.sig_rw.RW_READ, dtype=sig.sig_dtype.T), + sig.make_sig('u', dtype=sig.sig_dtype.T1) + ) + + @prim_attr_register + def __init__(self): + self.init_prim_io_names(inputs=['ref', 'u'], outputs=['output']) + + def check_dtype(self, variable): + if variable != mstype.type_refkey: + validator.check_tensor_type_same({"variable": variable}, mstype.number_type, self.name) class BoundingBoxEncode(PrimitiveWithInfer): """ @@ -354,9 +379,12 @@ class Partial(Primitive): FunctionType, partial function binded with arguments. """ + # Side effect will be propagated from the first argument to the return value. + side_effect_propagate = 1 + @prim_attr_register def __init__(self): - pass + self.add_prim_attr('side_effect_propagate', 1) def __call__(self, *args): func = args[0].__call__ @@ -419,13 +447,34 @@ class Depend(Primitive): [0.2 0.2 0.2 0.2 0.2]] """ + # Side effect will be propagated from the first argument to the return value. + side_effect_propagate = 1 + @prim_attr_register def __init__(self): - pass + self.add_prim_attr('side_effect_propagate', 1) def __call__(self, value, expr): return value +class UpdateState(Primitive): + """ + UpdateState is used to update the side-effect state. + + Inputs: + - **value** (State) - the state value to be updated.
+ - **expr** (Expression) - the expression to evaluate before state changes. + + Outputs: + State, the updated state value. + """ + + @prim_attr_register + def __init__(self): + pass + + def __call__(self, state, expr): + return state class CheckBprop(PrimitiveWithInfer): """ @@ -647,9 +696,12 @@ class identity(Primitive): The same as input. """ + # Side effect will be propagated from the first argument to the return value. + side_effect_propagate = 1 + @prim_attr_register def __init__(self): - pass + self.add_prim_attr('side_effect_propagate', 1) def __call__(self, x): return x diff --git a/mindspore/ops/operations/random_ops.py b/mindspore/ops/operations/random_ops.py index e7cf55d601..14f9a3dd2c 100644 --- a/mindspore/ops/operations/random_ops.py +++ b/mindspore/ops/operations/random_ops.py @@ -50,6 +50,7 @@ class StandardNormal(PrimitiveWithInfer): def __init__(self, seed=0, seed2=0): """Initialize StandardNormal""" self.init_prim_io_names(inputs=['shape'], outputs=['output']) + self.add_prim_attr('side_effect_mem', True) Validator.check_non_negative_int(seed, "seed", self.name) Validator.check_non_negative_int(seed2, "seed2", self.name) @@ -101,6 +102,7 @@ class StandardLaplace(PrimitiveWithInfer): def __init__(self, seed=0, seed2=0): """Initialize StandardLaplace""" self.init_prim_io_names(inputs=['shape'], outputs=['output']) + self.add_prim_attr('side_effect_mem', True) Validator.check_value_type('seed', seed, [int], self.name) Validator.check_value_type('seed2', seed2, [int], self.name) @@ -158,6 +160,7 @@ class Gamma(PrimitiveWithInfer): def __init__(self, seed=0, seed2=0): """Initialize Gamma""" self.init_prim_io_names(inputs=['shape', 'alpha', 'beta'], outputs=['output']) + self.add_prim_attr('side_effect_mem', True) Validator.check_non_negative_int(seed, "seed", self.name) Validator.check_non_negative_int(seed2, "seed2", self.name) @@ -216,6 +219,7 @@ class Poisson(PrimitiveWithInfer): def __init__(self, seed=0, seed2=0): """Initialize Poisson""" self.init_prim_io_names(inputs=['shape', 'mean'], outputs=['output']) + self.add_prim_attr('side_effect_mem', True) Validator.check_non_negative_int(seed, "seed", self.name) Validator.check_non_negative_int(seed2, "seed2", self.name) @@ -278,6 +282,7 @@ class UniformInt(PrimitiveWithInfer): def __init__(self, seed=0, seed2=0): """Initialize UniformInt""" self.init_prim_io_names(inputs=['shape', 'minval', 'maxval'], outputs=['output']) + self.add_prim_attr('side_effect_mem', True) Validator.check_non_negative_int(seed, "seed", self.name) Validator.check_non_negative_int(seed2, "seed2", self.name) @@ -331,6 +336,7 @@ class UniformReal(PrimitiveWithInfer): def __init__(self, seed=0, seed2=0): """Initialize UniformReal""" self.init_prim_io_names(inputs=['shape'], outputs=['output']) + self.add_prim_attr('side_effect_mem', True) Validator.check_non_negative_int(seed, "seed", self.name) Validator.check_non_negative_int(seed2, "seed2", self.name) @@ -394,6 +400,7 @@ class RandomChoiceWithMask(PrimitiveWithInfer): Validator.check_positive_int(count, "count", self.name) Validator.check_value_type('seed', seed, [int], self.name) Validator.check_value_type('seed2', seed2, [int], self.name) + self.add_prim_attr('side_effect_mem', True) def infer_shape(self, x_shape): Validator.check_int(len(x_shape), 1, Rel.GE, "input_x rank", self.name) @@ -450,6 +457,7 @@ class RandomCategorical(PrimitiveWithInfer): Validator.check_type_name("dtype", dtype, valid_values, self.name) self.init_prim_io_names(inputs=['logits', 'num_samples', 'seed'], outputs=['output']) +
self.add_prim_attr('side_effect_mem', True) def __infer__(self, logits, num_samples, seed): logits_dtype = logits['dtype'] @@ -508,6 +516,7 @@ class Multinomial(PrimitiveWithInfer): Validator.check_non_negative_int(seed, "seed", self.name) Validator.check_non_negative_int(seed2, "seed2", self.name) self.init_prim_io_names(inputs=['input', 'num_sample'], outputs=['output']) + self.add_prim_attr('side_effect_mem', True) def __infer__(self, inputs, num_samples): input_shape = inputs["shape"] diff --git a/model_zoo/official/cv/centerface/src/centerface.py b/model_zoo/official/cv/centerface/src/centerface.py index 4d0de21d49..0b72a835fd 100644 --- a/model_zoo/official/cv/centerface/src/centerface.py +++ b/model_zoo/official/cv/centerface/src/centerface.py @@ -156,7 +156,7 @@ class CenterfaceMobilev2(nn.Cell): class CenterFaceLoss(nn.Cell): """ - Loss method defination. + Loss method definition. """ def __init__(self, wh_weight, reg_offset, off_weight, hm_weight, lm_weight): super(CenterFaceLoss, self).__init__() @@ -260,8 +260,10 @@ class TrainingWrapper(nn.Cell): # init overflow buffer init = self.alloc_status() + init = F.depend(init, loss) # clear overflow buffer - self.clear_status(init) + clear_status = self.clear_status(init) + loss = F.depend(loss, clear_status) #sens = sens_input #P.Fill()(P.DType()(loss), P.Shape()(loss), sens_input) # user can contral loss scale by add a sens_input sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) @@ -272,7 +274,9 @@ class TrainingWrapper(nn.Cell): grads = self.grad_reducer(grads) # get the overflow buffer - self.get_status(init) + init = F.depend(init, grads) + get_status = self.get_status(init) + init = F.depend(init, get_status) # sum overflow buffer elements, 0:not overflow , >0:overflow flag_sum = self.reduce_sum(init, (0,)) if self.is_distributed: diff --git a/model_zoo/official/cv/resnet/src/resnet_gpu_benchmark.py b/model_zoo/official/cv/resnet/src/resnet_gpu_benchmark.py index 81fbe042a7..a164473cf5 100644 --- a/model_zoo/official/cv/resnet/src/resnet_gpu_benchmark.py +++ b/model_zoo/official/cv/resnet/src/resnet_gpu_benchmark.py @@ -123,6 +123,9 @@ class ResidualBlock(nn.Cell): def construct(self, x): identity = x + if self.down_sample: + identity = self.down_sample_layer(identity) + out = self.conv1(x) out = self.bn1(out) out = self.relu(out) @@ -132,10 +135,7 @@ class ResidualBlock(nn.Cell): out = self.conv3(out) out = self.bn3(out) - if self.down_sample: - identity = self.down_sample_layer(identity) - - out = self.add(out, identity) + out = self.add(identity, out) out = self.relu(out) return out diff --git a/model_zoo/official/cv/resnet_thor/src/thor.py b/model_zoo/official/cv/resnet_thor/src/thor.py index 3fffd2c709..9a50545b8e 100644 --- a/model_zoo/official/cv/resnet_thor/src/thor.py +++ b/model_zoo/official/cv/resnet_thor/src/thor.py @@ -136,8 +136,6 @@ class THOR_GPU(Optimizer): g = self.reshape(g, (g_shape[0], -1)) matrix_A = self.matrix_A[i] matrix_G = self.matrix_G[i] - matrix_A = F.depend(matrix_A, g) - matrix_G = F.depend(matrix_G, g) g = self.update_gradient(matrix_G, g, matrix_A) if i == 53: new_grads = new_grads + (g,) @@ -284,9 +282,6 @@ class THOR(Optimizer): matrix_A = self.matrix_A[i] matrix_G = self.matrix_G[i] matrix_max = self.matrix_max_inv[i] - matrix_A = F.depend(matrix_A, g) - matrix_G = F.depend(matrix_G, g) - matrix_max = F.depend(matrix_max, g) if i == 53: g = self.cube_matmul_left_fc(matrix_G, g) g = self.cube_matmul_right_fc(g, matrix_A, matrix_max) diff --git 
a/model_zoo/official/cv/resnet_thor/src/thor_layer.py b/model_zoo/official/cv/resnet_thor/src/thor_layer.py index 96746e1fa9..881c8da5c6 100644 --- a/model_zoo/official/cv/resnet_thor/src/thor_layer.py +++ b/model_zoo/official/cv/resnet_thor/src/thor_layer.py @@ -586,7 +586,6 @@ class Conv2d_Thor(_Conv): matrix_A_inv = self.reshape(matrix_A_inv, self.matrix_A_device_temp_shape) matrix_A_inv = self.transpose(matrix_A_inv, (2, 0, 1, 3)) self.matrix_A_inv = matrix_A_inv - self.matrix_G_inv = self.fake_G out = self.conv2d(x, self.weight) out = self.getG(out) else: @@ -751,7 +750,6 @@ class Dense_Thor(Cell): matrix_A_inv = self.transpose(matrix_A_inv, (2, 0, 1, 3)) matrix_A_inv = self.cast(matrix_A_inv, mstype.float16) self.matrix_A_inv = matrix_A_inv - self.matrix_G_inv = self.fake_G output = self.matmul(x, self.weight) output = self.getG(output) else: diff --git a/model_zoo/official/nlp/bert/src/bert_for_finetune.py b/model_zoo/official/nlp/bert/src/bert_for_finetune.py index 03bd3a1535..210339ccd0 100644 --- a/model_zoo/official/nlp/bert/src/bert_for_finetune.py +++ b/model_zoo/official/nlp/bert/src/bert_for_finetune.py @@ -91,9 +91,8 @@ class BertFinetuneCell(nn.Cell): else: self.alloc_status = P.NPUAllocFloatStatus() self.get_status = P.NPUGetFloatStatus() - self.clear_before_grad = P.NPUClearFloatStatus() + self.clear_status = P.NPUClearFloatStatus() self.reduce_sum = P.ReduceSum(keep_dims=False) - self.depend_parameter_use = P.ControlDepend(depend_mode=1) self.base = Tensor(1, mstype.float32) self.less_equal = P.LessEqual() self.hyper_map = C.HyperMap() @@ -123,9 +122,9 @@ class BertFinetuneCell(nn.Cell): if not self.gpu_target: init = self.alloc_status() - clear_before_grad = self.clear_before_grad(init) - F.control_depend(loss, init) - self.depend_parameter_use(clear_before_grad, scaling_sens) + init = F.depend(init, loss) + clear_status = self.clear_status(init) + scaling_sens = F.depend(scaling_sens, clear_status) grads = self.grad(self.network, weights)(input_ids, input_mask, token_type_id, @@ -137,10 +136,10 @@ class BertFinetuneCell(nn.Cell): if self.reducer_flag: grads = self.grad_reducer(grads) if not self.gpu_target: - flag = self.get_status(init) + init = F.depend(init, grads) + get_status = self.get_status(init) + init = F.depend(init, get_status) flag_sum = self.reduce_sum(init, (0,)) - F.control_depend(grads, flag) - F.control_depend(flag, flag_sum) else: flag_sum = self.hyper_map(F.partial(_grad_overflow), grads) flag_sum = self.addn(flag_sum) @@ -185,9 +184,8 @@ class BertSquadCell(nn.Cell): self.cast = P.Cast() self.alloc_status = P.NPUAllocFloatStatus() self.get_status = P.NPUGetFloatStatus() - self.clear_before_grad = P.NPUClearFloatStatus() + self.clear_status = P.NPUClearFloatStatus() self.reduce_sum = P.ReduceSum(keep_dims=False) - self.depend_parameter_use = P.ControlDepend(depend_mode=1) self.base = Tensor(1, mstype.float32) self.less_equal = P.LessEqual() self.hyper_map = C.HyperMap() @@ -219,6 +217,9 @@ class BertSquadCell(nn.Cell): scaling_sens = self.loss_scale else: scaling_sens = sens + init = F.depend(init, loss) + clear_status = self.clear_status(init) + scaling_sens = F.depend(scaling_sens, clear_status) grads = self.grad(self.network, weights)(input_ids, input_mask, token_type_id, @@ -228,22 +229,19 @@ class BertSquadCell(nn.Cell): is_impossible, self.cast(scaling_sens, mstype.float32)) - clear_before_grad = self.clear_before_grad(init) - F.control_depend(loss, init) - self.depend_parameter_use(clear_before_grad, scaling_sens) grads = 
self.hyper_map(F.partial(grad_scale, scaling_sens), grads) grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) if self.reducer_flag: grads = self.grad_reducer(grads) - flag = self.get_status(init) + init = F.depend(init, grads) + get_status = self.get_status(init) + init = F.depend(init, get_status) flag_sum = self.reduce_sum(init, (0,)) if self.is_distributed: flag_reduce = self.allreduce(flag_sum) cond = self.less_equal(self.base, flag_reduce) else: cond = self.less_equal(self.base, flag_sum) - F.control_depend(grads, flag) - F.control_depend(flag, flag_sum) overflow = cond if sens is None: overflow = self.loss_scaling_manager(self.loss_scale, cond) diff --git a/model_zoo/official/nlp/bert/src/bert_for_pre_training.py b/model_zoo/official/nlp/bert/src/bert_for_pre_training.py index ddb7532a4b..5113ea5927 100644 --- a/model_zoo/official/nlp/bert/src/bert_for_pre_training.py +++ b/model_zoo/official/nlp/bert/src/bert_for_pre_training.py @@ -69,6 +69,7 @@ class GetMaskedLMOutput(nn.Cell): Returns: Tensor, masked lm output. """ + def __init__(self, config): super(GetMaskedLMOutput, self).__init__() self.width = config.hidden_size @@ -126,6 +127,7 @@ class GetNextSentenceOutput(nn.Cell): Returns: Tensor, next sentence output. """ + def __init__(self, config): super(GetNextSentenceOutput, self).__init__() self.log_softmax = P.LogSoftmax() @@ -154,6 +156,7 @@ class BertPreTraining(nn.Cell): Returns: Tensor, prediction_scores, seq_relationship_score. """ + def __init__(self, config, is_training, use_one_hot_embeddings): super(BertPreTraining, self).__init__() self.bert = BertModel(config, is_training, use_one_hot_embeddings) @@ -181,6 +184,7 @@ class BertPretrainingLoss(nn.Cell): Returns: Tensor, total loss. """ + def __init__(self, config): super(BertPretrainingLoss, self).__init__() self.vocab_size = config.vocab_size @@ -231,6 +235,7 @@ class BertNetworkWithLoss(nn.Cell): Returns: Tensor, the loss of the network. """ + def __init__(self, config, is_training, use_one_hot_embeddings=False): super(BertNetworkWithLoss, self).__init__() self.bert = BertPreTraining(config, is_training, use_one_hot_embeddings) @@ -265,6 +270,7 @@ class BertTrainOneStepCell(nn.TrainOneStepCell): optimizer (Optimizer): Optimizer for updating the weights. sens (Number): The adjust parameter. Default: 1.0. """ + def __init__(self, network, optimizer, sens=1.0): super(BertTrainOneStepCell, self).__init__(network, optimizer, sens) self.cast = P.Cast() @@ -336,6 +342,7 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell): optimizer (Optimizer): Optimizer for updating the weights. scale_update_cell (Cell): Cell to do the loss scale. Default: None. 
""" + def __init__(self, network, optimizer, scale_update_cell=None): super(BertTrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False) self.network = network @@ -365,9 +372,8 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell): self.gpu_target = False self.alloc_status = P.NPUAllocFloatStatus() self.get_status = P.NPUGetFloatStatus() - self.clear_before_grad = P.NPUClearFloatStatus() + self.clear_status = P.NPUClearFloatStatus() self.reduce_sum = P.ReduceSum(keep_dims=False) - self.depend_parameter_use = P.ControlDepend(depend_mode=1) self.base = Tensor(1, mstype.float32) self.less_equal = P.LessEqual() self.hyper_map = C.HyperMap() @@ -376,7 +382,6 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell): if scale_update_cell: self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32)) - @C.add_flags(has_effect=True) def construct(self, input_ids, input_mask, @@ -403,7 +408,9 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell): if not self.gpu_target: # alloc status and clear should be right before gradoperation init = self.alloc_status() - self.clear_before_grad(init) + init = F.depend(init, loss) + clear_status = self.clear_status(init) + scaling_sens = F.depend(scaling_sens, clear_status) grads = self.grad(self.network, weights)(input_ids, input_mask, token_type_id, @@ -418,7 +425,9 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell): grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads) grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) if not self.gpu_target: - self.get_status(init) + init = F.depend(init, grads) + get_status = self.get_status(init) + init = F.depend(init, get_status) flag_sum = self.reduce_sum(init, (0,)) else: flag_sum = self.hyper_map(F.partial(_grad_overflow), grads) @@ -483,7 +492,7 @@ class BertTrainOneStepWithLossScaleCellForAdam(nn.Cell): self.gpu_target = False self.alloc_status = P.NPUAllocFloatStatus() self.get_status = P.NPUGetFloatStatus() - self.clear_before_grad = P.NPUClearFloatStatus() + self.clear_status = P.NPUClearFloatStatus() self.reduce_sum = P.ReduceSum(keep_dims=False) self.depend_parameter_use = P.ControlDepend(depend_mode=1) self.base = Tensor(1, mstype.float32) @@ -494,7 +503,6 @@ class BertTrainOneStepWithLossScaleCellForAdam(nn.Cell): if scale_update_cell: self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32)) - @C.add_flags(has_effect=True) def construct(self, input_ids, input_mask, @@ -521,7 +529,10 @@ class BertTrainOneStepWithLossScaleCellForAdam(nn.Cell): if not self.gpu_target: # alloc status and clear should be right before gradoperation init = self.alloc_status() - self.clear_before_grad(init) + init = F.depend(init, loss) + clear_status = self.clear_status(init) + scaling_sens = F.depend(scaling_sens, clear_status) + grads = self.grad(self.network, weights)(input_ids, input_mask, token_type_id, @@ -536,7 +547,9 @@ class BertTrainOneStepWithLossScaleCellForAdam(nn.Cell): grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads) grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) if not self.gpu_target: - self.get_status(init) + init = F.depend(init, grads) + get_status = self.get_status(init) + init = F.depend(init, get_status) flag_sum = self.reduce_sum(init, (0,)) else: flag_sum = self.hyper_map(F.partial(_grad_overflow), grads) @@ -607,6 +620,7 @@ class BertTrainAccumulationAllReducePostWithLossScaleCell(nn.Cell): 
accumulation_steps (int): Number of accumulation steps before gradient update. The global batch size = batch_size * accumulation_steps. Default: 1. """ + def __init__(self, network, optimizer, scale_update_cell=None, accumulation_steps=1, enable_global_norm=False): super(BertTrainAccumulationAllReducePostWithLossScaleCell, self).__init__(auto_prefix=False) self.network = network @@ -639,7 +653,7 @@ class BertTrainAccumulationAllReducePostWithLossScaleCell(nn.Cell): self.cast = P.Cast() self.alloc_status = P.NPUAllocFloatStatus() self.get_status = P.NPUGetFloatStatus() - self.clear_before_grad = P.NPUClearFloatStatus() + self.clear_status = P.NPUClearFloatStatus() self.reduce_sum = P.ReduceSum(keep_dims=False) self.base = Tensor(1, mstype.float32) self.less_equal = P.LessEqual() @@ -653,7 +667,6 @@ class BertTrainAccumulationAllReducePostWithLossScaleCell(nn.Cell): if scale_update_cell: self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32)) - @C.add_flags(has_effect=True) def construct(self, input_ids, input_mask, @@ -676,7 +689,11 @@ class BertTrainAccumulationAllReducePostWithLossScaleCell(nn.Cell): scaling_sens = self.loss_scale else: scaling_sens = sens - + # alloc status and clear should be right before the grad operation + init = self.alloc_status() + init = F.depend(init, loss) + clear_status = self.clear_status(init) + scaling_sens = F.depend(scaling_sens, clear_status) # update accumulation parameters is_accu_step = self.not_equal(self.local_step, self.accumulation_steps) self.local_step = self.select(is_accu_step, self.local_step + self.one, self.one) @@ -684,9 +701,6 @@ class BertTrainAccumulationAllReducePostWithLossScaleCell(nn.Cell): mean_loss = self.accu_loss / self.local_step is_accu_step = self.not_equal(self.local_step, self.accumulation_steps) - # alloc status and clear should be right before gradoperation - init = self.alloc_status() - self.clear_before_grad(init) grads = self.grad(self.network, weights)(input_ids, input_mask, token_type_id, @@ -700,7 +714,9 @@ class BertTrainAccumulationAllReducePostWithLossScaleCell(nn.Cell): accu_succ = self.hyper_map(accumulate_accu_grads, self.accu_grads, grads) mean_loss = F.depend(mean_loss, accu_succ) - self.get_status(init) + init = F.depend(init, mean_loss) + get_status = self.get_status(init) + init = F.depend(init, get_status) flag_sum = self.reduce_sum(init, (0,)) overflow = self.less_equal(self.base, flag_sum) overflow = self.logical_or(self.not_equal(self.accu_overflow, self.zero), overflow) @@ -718,8 +734,8 @@ class BertTrainAccumulationAllReducePostWithLossScaleCell(nn.Cell): grads = C.clip_by_global_norm(grads, 1.0, None) else: grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) + accu_overflow = F.depend(accu_overflow, grads) accu_overflow = self.overflow_reducer(accu_overflow) - F.control_depend(grads, accu_overflow) overflow = self.less_equal(self.base, accu_overflow) accu_succ = self.hyper_map(reset_accu_grads, self.accu_grads) overflow = F.depend(overflow, accu_succ) diff --git a/model_zoo/official/nlp/bert_thor/src/bert_for_pre_training.py b/model_zoo/official/nlp/bert_thor/src/bert_for_pre_training.py index b336064b48..296e8f303a 100644 --- a/model_zoo/official/nlp/bert_thor/src/bert_for_pre_training.py +++ b/model_zoo/official/nlp/bert_thor/src/bert_for_pre_training.py @@ -369,9 +369,8 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell): self.cast = P.Cast() self.alloc_status = P.NPUAllocFloatStatus() self.get_status =
P.NPUGetFloatStatus() - self.clear_before_grad = P.NPUClearFloatStatus() + self.clear_status = P.NPUClearFloatStatus() self.reduce_sum = P.ReduceSum(keep_dims=False) - self.depend_parameter_use = P.ControlDepend(depend_mode=1) self.base = Tensor(1, mstype.float32) self.less_equal = P.LessEqual() self.hyper_map = C.HyperMap() @@ -380,7 +379,6 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell): if scale_update_cell: self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32)) - @C.add_flags(has_effect=True) def construct(self, input_ids, input_mask, @@ -405,7 +403,9 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell): scaling_sens = sens # alloc status and clear should be right before gradoperation init = self.alloc_status() - self.clear_before_grad(init) + init = F.depend(init, loss) + clear_status = self.clear_status(init) + scaling_sens = F.depend(scaling_sens, clear_status) grads = self.grad(self.network, weights)(input_ids, input_mask, token_type_id, @@ -419,7 +419,9 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell): grads = self.grad_reducer(grads) grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads) grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) - self.get_status(init) + init = F.depend(init, grads) + get_status = self.get_status(init) + init = F.depend(init, get_status) flag_sum = self.reduce_sum(init, (0,)) if self.is_distributed: # sum overflow flag over devices diff --git a/model_zoo/official/nlp/bert_thor/src/thor_for_bert.py b/model_zoo/official/nlp/bert_thor/src/thor_for_bert.py index 51a436e12d..00ab4c45ca 100644 --- a/model_zoo/official/nlp/bert_thor/src/thor_for_bert.py +++ b/model_zoo/official/nlp/bert_thor/src/thor_for_bert.py @@ -236,8 +236,6 @@ class THOR(Optimizer): matrix_idx = em_idx temp_a = self.matrix_A[matrix_idx] temp_g = self.matrix_G[matrix_idx] - temp_a = F.depend(temp_a, g) - temp_g = F.depend(temp_g, g) temp_a = self.expand(temp_a, 1) temp_a = self.cast(temp_a, mstype.float16) temp_g = self.cast(temp_g, mstype.float16) @@ -299,8 +297,6 @@ class THOR(Optimizer): matrix_idx = 6 * i + offset_idx + 3 temp_a = self.matrix_A[matrix_idx] temp_g = self.matrix_G[matrix_idx] - temp_a = F.depend(temp_a, g) - temp_g = F.depend(temp_g, g) temp_a = self.cast(temp_a, mstype.float16) temp_g = self.cast(temp_g, mstype.float16) g = self.cast(g, mstype.float16) @@ -317,8 +313,6 @@ class THOR(Optimizer): pooler_bias = gradients[pooler_layer_idx + 1] temp_a = self.matrix_A[matrix_idx] temp_g = self.matrix_G[matrix_idx] - temp_a = F.depend(temp_a, g) - temp_g = F.depend(temp_g, g) temp_a = self.cast(temp_a, mstype.float16) temp_g = self.cast(temp_g, mstype.float16) g = self.cast(g, mstype.float16) @@ -334,8 +328,6 @@ class THOR(Optimizer): mlm_bias = gradients[mlm_fc_idx + 1] temp_a = self.matrix_A[matrix_idx] temp_g = self.matrix_G[matrix_idx] - temp_a = F.depend(temp_a, g) - temp_g = F.depend(temp_g, g) temp_a = self.cast(temp_a, mstype.float16) temp_g = self.cast(temp_g, mstype.float16) g = self.cast(g, mstype.float16) diff --git a/model_zoo/official/nlp/bert_thor/src/thor_for_bert_arg.py b/model_zoo/official/nlp/bert_thor/src/thor_for_bert_arg.py index fcb2638283..911f99f5a2 100644 --- a/model_zoo/official/nlp/bert_thor/src/thor_for_bert_arg.py +++ b/model_zoo/official/nlp/bert_thor/src/thor_for_bert_arg.py @@ -242,8 +242,6 @@ class THOR(Optimizer): matrix_idx = em_idx temp_a = self.matrix_A[matrix_idx] temp_g = self.matrix_G[matrix_idx] - temp_a = F.depend(temp_a, g) 
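The THOR hunks here drop F.depend ties that existed only to force an execution order which the auto-monad pass now guarantees on its own, while the loss-scale cells in the surrounding files all converge on the same replacement for P.ControlDepend and the add_flags(has_effect=True) escape hatch: the NPU overflow-status sequence is ordered through explicit F.depend data dependencies. A condensed sketch of that sequence follows; names are shortened, and grad_fn is a hypothetical stand-in for the cell's gradient call, not a function from any of these files.

    import mindspore.ops.operations as P
    from mindspore.ops import functional as F

    alloc_status = P.NPUAllocFloatStatus()
    clear_status = P.NPUClearFloatStatus()
    get_status = P.NPUGetFloatStatus()
    reduce_sum = P.ReduceSum(keep_dims=False)

    def check_overflow(loss, scaling_sens, grad_fn):
        init = alloc_status()
        init = F.depend(init, loss)                      # clear only after the forward pass
        cleared = clear_status(init)
        scaling_sens = F.depend(scaling_sens, cleared)   # backward waits for the clear
        grads = grad_fn(scaling_sens)                    # gradients computed between clear and get
        init = F.depend(init, grads)                     # status readback only after backward
        got = get_status(init)
        init = F.depend(init, got)                       # reduce only after the readback
        flag_sum = reduce_sum(init, (0,))                # 0: no overflow, > 0: overflow
        return grads, flag_sum

Since F.depend(a, b) returns a with a data edge to b, each rebinding of init or scaling_sens encodes one of the orderings that the removed ControlDepend edges and has_effect flags used to express.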
- temp_g = F.depend(temp_g, g) temp_a = self.expand(temp_a, 1) temp_a = self.cast(temp_a, mstype.float16) temp_g = self.cast(temp_g, mstype.float16) @@ -305,8 +303,6 @@ class THOR(Optimizer): matrix_idx = 6 * i + offset_idx + 3 temp_a = self.matrix_A[matrix_idx] temp_g = self.matrix_G[matrix_idx] - temp_a = F.depend(temp_a, g) - temp_g = F.depend(temp_g, g) temp_a = self.cast(temp_a, mstype.float16) temp_g = self.cast(temp_g, mstype.float16) g = self.cast(g, mstype.float16) @@ -323,8 +319,6 @@ class THOR(Optimizer): pooler_bias = gradients[pooler_layer_idx + 1] temp_a = self.matrix_A[matrix_idx] temp_g = self.matrix_G[matrix_idx] - temp_a = F.depend(temp_a, g) - temp_g = F.depend(temp_g, g) temp_a = self.cast(temp_a, mstype.float16) temp_g = self.cast(temp_g, mstype.float16) g = self.cast(g, mstype.float16) @@ -340,8 +334,6 @@ class THOR(Optimizer): mlm_bias = gradients[mlm_fc_idx + 1] temp_a = self.matrix_A[matrix_idx] temp_g = self.matrix_G[matrix_idx] - temp_a = F.depend(temp_a, g) - temp_g = F.depend(temp_g, g) temp_a = self.cast(temp_a, mstype.float16) temp_g = self.cast(temp_g, mstype.float16) g = self.cast(g, mstype.float16) diff --git a/model_zoo/official/nlp/bert_thor/src/thor_layer.py b/model_zoo/official/nlp/bert_thor/src/thor_layer.py index 8e37c72d0a..eafe06b1e5 100644 --- a/model_zoo/official/nlp/bert_thor/src/thor_layer.py +++ b/model_zoo/official/nlp/bert_thor/src/thor_layer.py @@ -133,7 +133,6 @@ class Embedding_Thor(Cell): matrix_A_inv = self.inv(matrix_A) matrix_A_inv = self.cast(matrix_A_inv, mstype.float16) self.matrix_A_inv = matrix_A_inv - self.matrix_G_inv = self.fake_G output_for_reshape = self.gather(self.embedding_table, flat_ids, 0) output_for_reshape = self.getG(output_for_reshape) else: @@ -253,7 +252,6 @@ class Dense_Thor(Cell): matrix_A_inv = self.matrix_combine(matrix_A_inv) matrix_A_inv = self.cast(matrix_A_inv, mstype.float16) self.matrix_A_inv = matrix_A_inv - self.matrix_G_inv = self.fake_G output = self.matmul(x, self.weight) output = self.getG(output) else: diff --git a/model_zoo/official/nlp/gnmt_v2/src/gnmt_model/gnmt_for_train.py b/model_zoo/official/nlp/gnmt_v2/src/gnmt_model/gnmt_for_train.py index c96d158376..fa7f0175b1 100644 --- a/model_zoo/official/nlp/gnmt_v2/src/gnmt_model/gnmt_for_train.py +++ b/model_zoo/official/nlp/gnmt_v2/src/gnmt_model/gnmt_for_train.py @@ -243,9 +243,8 @@ class GNMTTrainOneStepWithLossScaleCell(nn.Cell): self.cast = P.Cast() self.alloc_status = P.NPUAllocFloatStatus() self.get_status = P.NPUGetFloatStatus() - self.clear_before_grad = P.NPUClearFloatStatus() + self.clear_status = P.NPUClearFloatStatus() self.reduce_sum = P.ReduceSum(keep_dims=False) - self.depend_parameter_use = P.ControlDepend(depend_mode=1) self.base = Tensor(1, mstype.float32) self.less_equal = P.LessEqual() self.hyper_map = C.HyperMap() @@ -254,7 +253,6 @@ class GNMTTrainOneStepWithLossScaleCell(nn.Cell): self.loss_scaling_manager = scale_update_cell if scale_update_cell: self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32)) - self.add_flags(has_effect=True) self.loss_scalar = P.ScalarSummary() @@ -291,14 +289,16 @@ class GNMTTrainOneStepWithLossScaleCell(nn.Cell): target_ids, label_ids, label_weights) - # Alloc status. - init = self.alloc_status() - # Clear overflow buffer. - self.clear_before_grad(init) if sens is None: scaling_sens = self.loss_scale else: scaling_sens = sens + # Alloc status. + init = self.alloc_status() + # Clear overflow buffer. 
+ init = F.depend(init, loss) + clear_status = self.clear_status(init) + scaling_sens = F.depend(scaling_sens, clear_status) grads = self.grad(self.network, weights)(source_ids, source_mask, target_ids, @@ -312,7 +312,9 @@ class GNMTTrainOneStepWithLossScaleCell(nn.Cell): if self.reducer_flag: # Apply grad reducer on grads. grads = self.grad_reducer(grads) - self.get_status(init) + init = F.depend(init, grads) + get_status = self.get_status(init) + init = F.depend(init, get_status) flag_sum = self.reduce_sum(init, (0,)) if self.is_distributed: diff --git a/model_zoo/official/nlp/gnmt_v2/src/utils/optimizer.py b/model_zoo/official/nlp/gnmt_v2/src/utils/optimizer.py index 46981a7c1d..1a0206b72d 100644 --- a/model_zoo/official/nlp/gnmt_v2/src/utils/optimizer.py +++ b/model_zoo/official/nlp/gnmt_v2/src/utils/optimizer.py @@ -414,7 +414,5 @@ class AdamWeightDecayDynamicLR(Optimizer): self.params, self.moments1, self.moments2, gradients, self.decay_flag) added_global_step = self.global_step + self.one - F.control_depend(lr, added_global_step) self.global_step = added_global_step - return updated_velocity diff --git a/model_zoo/official/nlp/gpt/src/gpt_wrapcell.py b/model_zoo/official/nlp/gpt/src/gpt_wrapcell.py index 59858d3bc6..b995daf283 100644 --- a/model_zoo/official/nlp/gpt/src/gpt_wrapcell.py +++ b/model_zoo/official/nlp/gpt/src/gpt_wrapcell.py @@ -97,9 +97,8 @@ class GPTTrainOneStepWithLossScaleCell(nn.Cell): self.cast = P.Cast() self.alloc_status = P.NPUAllocFloatStatus() self.get_status = P.NPUGetFloatStatus() - self.clear_before_grad = P.NPUClearFloatStatus() + self.clear_status = P.NPUClearFloatStatus() self.reduce_sum = P.ReduceSum(keep_dims=False) - self.depend_parameter_use = P.ControlDepend(depend_mode=1) self.base = Tensor(1, mstype.float32) self.less_equal = P.LessEqual() self.hyper_map = C.HyperMap() @@ -108,7 +107,6 @@ class GPTTrainOneStepWithLossScaleCell(nn.Cell): if scale_update_cell: self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32)) - @C.add_flags(has_effect=True) def construct(self, input_ids, past=None, @@ -124,7 +122,9 @@ class GPTTrainOneStepWithLossScaleCell(nn.Cell): scaling_sens = sens # alloc status and clear should be right before gradoperation init = self.alloc_status() - self.clear_before_grad(init) + init = F.depend(init, loss) + clear_status = self.clear_status(init) + scaling_sens = F.depend(scaling_sens, clear_status) grads = self.grad(self.network, weights)(input_ids, past, self.cast(scaling_sens, @@ -138,7 +138,9 @@ class GPTTrainOneStepWithLossScaleCell(nn.Cell): else: grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) - self.get_status(init) + init = F.depend(init, grads) + get_status = self.get_status(init) + init = F.depend(init, get_status) flag_sum = self.reduce_sum(init, (0,)) if self.is_distributed: # sum overflow flag over devices diff --git a/model_zoo/official/nlp/mass/src/transformer/transformer_for_train.py b/model_zoo/official/nlp/mass/src/transformer/transformer_for_train.py index da683752a5..24fab3ddd6 100644 --- a/model_zoo/official/nlp/mass/src/transformer/transformer_for_train.py +++ b/model_zoo/official/nlp/mass/src/transformer/transformer_for_train.py @@ -269,7 +269,6 @@ class TransformerTrainOneStepWithLossScaleCell(nn.Cell): self.get_status = P.NPUGetFloatStatus() self.clear_status = P.NPUClearFloatStatus() self.reduce_sum = P.ReduceSum(keep_dims=False) - self.depend_parameter_use = P.ControlDepend(depend_mode=1) self.base = Tensor(1, 
mstype.float32) self.less_equal = P.LessEqual() self.hyper_map = C.HyperMap() @@ -278,7 +277,6 @@ class TransformerTrainOneStepWithLossScaleCell(nn.Cell): self.loss_scaling_manager = scale_update_cell if scale_update_cell: self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32)) - self.add_flags(has_effect=True) def construct(self, source_eos_ids, @@ -318,16 +316,19 @@ class TransformerTrainOneStepWithLossScaleCell(nn.Cell): label_ids, label_weights) + if sens is None: + scaling_sens = self.loss_scale + else: + scaling_sens = sens + init = False if not self.gpu_target: # init overflow buffer init = self.alloc_status() + init = F.depend(init, loss) # clear overflow buffer - self.clear_status(init) - if sens is None: - scaling_sens = self.loss_scale - else: - scaling_sens = sens + clear_status = self.clear_status(init) + scaling_sens = F.depend(scaling_sens, clear_status) grads = self.grad(self.network, weights)(source_ids, source_mask, @@ -347,7 +348,9 @@ class TransformerTrainOneStepWithLossScaleCell(nn.Cell): # get the overflow buffer if not self.gpu_target: - self.get_status(init) + init = F.depend(init, grads) + get_status = self.get_status(init) + init = F.depend(init, get_status) # sum overflow buffer elements, 0:not overflow , >0:overflow flag_sum = self.reduce_sum(init, (0,)) else: diff --git a/model_zoo/official/nlp/tinybert/src/tinybert_for_gd_td.py b/model_zoo/official/nlp/tinybert/src/tinybert_for_gd_td.py index f8a1bb2263..6a90a271f3 100644 --- a/model_zoo/official/nlp/tinybert/src/tinybert_for_gd_td.py +++ b/model_zoo/official/nlp/tinybert/src/tinybert_for_gd_td.py @@ -234,9 +234,8 @@ class BertTrainWithLossScaleCell(nn.Cell): self.cast = P.Cast() self.alloc_status = P.NPUAllocFloatStatus() self.get_status = P.NPUGetFloatStatus() - self.clear_before_grad = P.NPUClearFloatStatus() + self.clear_status = P.NPUClearFloatStatus() self.reduce_sum = P.ReduceSum(keep_dims=False) - self.depend_parameter_use = P.ControlDepend(depend_mode=1) self.base = Tensor(1, mstype.float32) self.less_equal = P.LessEqual() self.hyper_map = C.HyperMap() @@ -245,7 +244,6 @@ class BertTrainWithLossScaleCell(nn.Cell): if scale_update_cell: self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32)) - @C.add_flags(has_effect=True) def construct(self, input_ids, input_mask, @@ -262,7 +260,9 @@ class BertTrainWithLossScaleCell(nn.Cell): scaling_sens = sens # alloc status and clear should be right before gradoperation init = self.alloc_status() - self.clear_before_grad(init) + init = F.depend(init, loss) + clear_status = self.clear_status(init) + scaling_sens = F.depend(scaling_sens, clear_status) grads = self.grad(self.network, weights)(input_ids, input_mask, token_type_id, @@ -272,7 +272,9 @@ class BertTrainWithLossScaleCell(nn.Cell): grads = self.grad_reducer(grads) grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads) grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) - self.get_status(init) + init = F.depend(init, grads) + get_status = self.get_status(init) + init = F.depend(init, get_status) flag_sum = self.reduce_sum(init, (0,)) if self.is_distributed: # sum overflow flag over devices @@ -463,7 +465,7 @@ class BertNetworkWithLoss_td(nn.Cell): class BertEvaluationWithLossScaleCell(nn.Cell): """ - Especifically defined for finetuning where only four inputs tensor are needed. + Specifically defined for finetuning where only four input tensors are needed.
""" def __init__(self, network, optimizer, scale_update_cell=None): super(BertEvaluationWithLossScaleCell, self).__init__(auto_prefix=False) @@ -487,9 +489,8 @@ class BertEvaluationWithLossScaleCell(nn.Cell): self.cast = P.Cast() self.alloc_status = P.NPUAllocFloatStatus() self.get_status = P.NPUGetFloatStatus() - self.clear_before_grad = P.NPUClearFloatStatus() + self.clear_status = P.NPUClearFloatStatus() self.reduce_sum = P.ReduceSum(keep_dims=False) - self.depend_parameter_use = P.ControlDepend(depend_mode=1) self.base = Tensor(1, mstype.float32) self.less_equal = P.LessEqual() self.hyper_map = C.HyperMap() @@ -498,7 +499,6 @@ class BertEvaluationWithLossScaleCell(nn.Cell): if scale_update_cell: self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32)) - @C.add_flags(has_effect=True) def construct(self, input_ids, input_mask, @@ -517,7 +517,9 @@ class BertEvaluationWithLossScaleCell(nn.Cell): scaling_sens = sens # alloc status and clear should be right before gradoperation init = self.alloc_status() - self.clear_before_grad(init) + init = F.depend(init, loss) + clear_status = self.clear_status(init) + scaling_sens = F.depend(scaling_sens, clear_status) grads = self.grad(self.network, weights)(input_ids, input_mask, token_type_id, @@ -528,7 +530,9 @@ class BertEvaluationWithLossScaleCell(nn.Cell): grads = self.grad_reducer(grads) grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads) grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) - self.get_status(init) + init = F.depend(init, grads) + get_status = self.get_status(init) + init = F.depend(init, get_status) flag_sum = self.reduce_sum(init, (0,)) if self.is_distributed: # sum overflow flag over devices @@ -549,7 +553,7 @@ class BertEvaluationCell(nn.Cell): """ - Especifically defined for finetuning where only four inputs tensor are needed. + Specifically defined for finetuning where only four input tensors are needed.
""" def __init__(self, network, optimizer, sens=1.0): super(BertEvaluationCell, self).__init__(auto_prefix=False) diff --git a/model_zoo/official/nlp/transformer/src/transformer_for_train.py b/model_zoo/official/nlp/transformer/src/transformer_for_train.py index 4a7f083f71..833beaf784 100644 --- a/model_zoo/official/nlp/transformer/src/transformer_for_train.py +++ b/model_zoo/official/nlp/transformer/src/transformer_for_train.py @@ -281,9 +281,8 @@ class TransformerTrainOneStepWithLossScaleCell(nn.Cell): self.gpu_target = False self.alloc_status = P.NPUAllocFloatStatus() self.get_status = P.NPUGetFloatStatus() - self.clear_before_grad = P.NPUClearFloatStatus() + self.clear_status = P.NPUClearFloatStatus() self.reduce_sum = P.ReduceSum(keep_dims=False) - self.depend_parameter_use = P.ControlDepend(depend_mode=1) self.base = Tensor(1, mstype.float32) self.less_equal = P.LessEqual() self.hyper_map = C.HyperMap() @@ -293,7 +292,6 @@ class TransformerTrainOneStepWithLossScaleCell(nn.Cell): if scale_update_cell: self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32)) - @C.add_flags(has_effect=True) def construct(self, source_eos_ids, source_eos_mask, @@ -317,16 +315,18 @@ class TransformerTrainOneStepWithLossScaleCell(nn.Cell): target_mask, label_ids, label_weights) + if sens is None: + scaling_sens = self.loss_scale + else: + scaling_sens = sens init = False if not self.gpu_target: # alloc status init = self.alloc_status() # clear overflow buffer - self.clear_before_grad(init) - if sens is None: - scaling_sens = self.loss_scale - else: - scaling_sens = sens + init = F.depend(init, loss) + clear_status = self.clear_status(init) + scaling_sens = F.depend(scaling_sens, clear_status) grads = self.grad(self.network, weights)(source_ids, source_mask, target_ids, @@ -343,7 +343,9 @@ class TransformerTrainOneStepWithLossScaleCell(nn.Cell): grads = self.grad_reducer(grads) if not self.gpu_target: - self.get_status(init) + init = F.depend(init, grads) + get_status = self.get_status(init) + init = F.depend(init, get_status) # sum overflow buffer elements, 0: not overflow, >0: overflow flag_sum = self.reduce_sum(init, (0,)) else: diff --git a/model_zoo/research/cv/FaceDetection/eval.py b/model_zoo/research/cv/FaceDetection/eval.py index b87ea97e66..fd42a35153 100644 --- a/model_zoo/research/cv/FaceDetection/eval.py +++ b/model_zoo/research/cv/FaceDetection/eval.py @@ -32,7 +32,7 @@ from src.config import config from src.FaceDetection.yolov3 import HwYolov3 as backbone_HwYolov3 from src.FaceDetection import voc_wrapper from src.network_define import BuildTestNetwork, get_bounding_boxes, tensor_to_brambox, \ - parse_gt_from_anno, parse_rets, calc_recall_presicion_ap + parse_gt_from_anno, parse_rets, calc_recall_precision_ap plt.switch_backend('agg') devid = int(os.getenv('DEVICE_ID')) @@ -186,7 +186,7 @@ def val(args): ret_list = parse_rets(ret_files_set) iou_thr = 0.5 - evaluate = calc_recall_presicion_ap(ground_truth, ret_list, iou_thr) + evaluate = calc_recall_precision_ap(ground_truth, ret_list, iou_thr) aps_str = '' for cls in evaluate: diff --git a/model_zoo/research/cv/FaceDetection/src/network_define.py b/model_zoo/research/cv/FaceDetection/src/network_define.py index a8df58508c..d6d541768c 100644 --- a/model_zoo/research/cv/FaceDetection/src/network_define.py +++ b/model_zoo/research/cv/FaceDetection/src/network_define.py @@ -87,7 +87,6 @@ class TrainOneStepWithLossScaleCell(nn.Cell): if scale_update_cell: self.loss_scale = 
Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32), name="loss_scale") - self.add_flags(has_effect=True) def construct(self, data, coord_mask, conf_pos_mask, conf_neg_mask, cls_mask, t_coord, t_conf, t_cls, gt_list, coord_mask_1, conf_pos_mask_1, conf_neg_mask_1, cls_mask_1, t_coord_1, t_conf_1, t_cls_1, gt_list_1, @@ -100,16 +99,17 @@ class TrainOneStepWithLossScaleCell(nn.Cell): coord_mask_1, conf_pos_mask_1, conf_neg_mask_1, cls_mask_1, t_coord_1, t_conf_1, t_cls_1, gt_list_1, coord_mask_2, conf_pos_mask_2, conf_neg_mask_2, cls_mask_2, t_coord_2, t_conf_2, t_cls_2, gt_list_2) - # init overflow buffer init = self.alloc_status() - # clear overflow buffer - self.clear_status(init) + init = F.depend(init, loss) + clear_status = self.clear_status(init) + if sens is None: scaling_sens = self.loss_scale else: scaling_sens = sens + scaling_sens = F.depend(scaling_sens, clear_status) grads = self.grad(self.network, weights)(data, coord_mask, conf_pos_mask, conf_neg_mask, cls_mask, t_coord, t_conf, t_cls, gt_list, coord_mask_1, conf_pos_mask_1, conf_neg_mask_1, @@ -122,7 +122,9 @@ class TrainOneStepWithLossScaleCell(nn.Cell): grads = self.grad_reducer(grads) # get the overflow buffer - self.get_status(init) + init = F.depend(init, grads) + get_status = self.get_status(init) + init = F.depend(init, get_status) # sum overflow buffer elements, 0:not overflow , >0:overflow flag_sum = self.reduce_sum(init, (0,)) @@ -551,11 +553,11 @@ def cal_ap_voc2012(recall, precision): ap_val = 0.0 eps = 1e-6 assert len(recall) == len(precision) - lenght = len(recall) - cur_prec = precision[lenght - 1] - cur_rec = recall[lenght - 1] + length = len(recall) + cur_prec = precision[length - 1] + cur_rec = recall[length - 1] - for i in range(0, lenght - 1)[::-1]: + for i in range(0, length - 1)[::-1]: cur_prec = max(precision[i], cur_prec) if abs(recall[i] - cur_rec) > eps: ap_val += cur_prec * abs(recall[i] - cur_rec) @@ -588,8 +590,8 @@ def cal_ap_11point(recall, precision): return ap_val -def calc_recall_presicion_ap(ground_truth, ret_list, iou_thr=0.5): - '''calc_recall_presicion_ap''' +def calc_recall_precision_ap(ground_truth, ret_list, iou_thr=0.5): + '''calc_recall_precision_ap''' print('calculate [recall | persicion | ap]...') evaluate = {} for cls in ret_list: @@ -628,8 +630,8 @@ def calc_recall_presicion_ap(ground_truth, ret_list, iou_thr=0.5): fp = fp.cumsum() recall = tp / n_gt_obj - presicion = tp / (tp + fp) - ap = cal_ap_voc2012(recall, presicion) - evaluate[cls] = {'recall': recall, 'presicion': presicion, 'ap': ap} + precision = tp / (tp + fp) + ap = cal_ap_voc2012(recall, precision) + evaluate[cls] = {'recall': recall, 'precision': precision, 'ap': ap} return evaluate diff --git a/model_zoo/research/cv/centernet/src/centernet_pose.py b/model_zoo/research/cv/centernet/src/centernet_pose.py index c6c560e6a7..c558156a03 100644 --- a/model_zoo/research/cv/centernet/src/centernet_pose.py +++ b/model_zoo/research/cv/centernet/src/centernet_pose.py @@ -274,14 +274,13 @@ class CenterNetWithLossScaleCell(nn.Cell): self.cast = ops.Cast() self.alloc_status = ops.NPUAllocFloatStatus() self.get_status = ops.NPUGetFloatStatus() - self.clear_before_grad = ops.NPUClearFloatStatus() + self.clear_status = ops.NPUClearFloatStatus() self.reduce_sum = ops.ReduceSum(keep_dims=False) self.base = Tensor(1, mstype.float32) self.less_equal = ops.LessEqual() self.grad_scale = GradScale() self.loss_scale = sens - @ops.add_flags(has_effect=True) def construct(self, image, hm, reg_mask, ind, wh, kps, 
kps_mask, reg, hm_hp, hp_offset, hp_ind, hp_mask): """Defines the computation performed.""" @@ -292,13 +291,17 @@ class CenterNetWithLossScaleCell(nn.Cell): scaling_sens = self.cast(self.loss_scale, mstype.float32) * 2.0 / 2.0 # alloc status and clear should be right before gradoperation init = self.alloc_status() - self.clear_before_grad(init) + init = ops.depend(init, scaling_sens) + clear_status = self.clear_status(init) + scaling_sens = ops.depend(scaling_sens, clear_status) grads = self.grad(self.network, weights)(image, hm, reg_mask, ind, wh, kps, kps_mask, reg, hm_hp, hp_offset, hp_ind, hp_mask, scaling_sens) grads = self.grad_reducer(grads) grads = self.grad_scale(scaling_sens * self.degree, grads) - self.get_status(init) + init = ops.depend(init, grads) + get_status = self.get_status(init) + init = ops.depend(init, get_status) flag_sum = self.reduce_sum(init, (0,)) if self.is_distributed: flag_reduce = self.allreduce(flag_sum) diff --git a/tests/st/auto_monad/capture.py b/tests/st/auto_monad/capture.py new file mode 100644 index 0000000000..13ab61bafe --- /dev/null +++ b/tests/st/auto_monad/capture.py @@ -0,0 +1,48 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import os +import sys +import tempfile +from contextlib import contextmanager + +class Capture(): + def start(self): + self._old_stdout = sys.stdout + self._stdout_fd = self._old_stdout.fileno() + self._saved_stdout_fd = os.dup(self._stdout_fd) + self._file = sys.stdout = tempfile.TemporaryFile(mode='w+t') + self.output = '' + os.dup2(self._file.fileno(), self._stdout_fd) + + def stop(self): + os.dup2(self._saved_stdout_fd, self._stdout_fd) + os.close(self._saved_stdout_fd) + sys.stdout = self._old_stdout + self._file.seek(0) + self.output = self._file.read() + self._file.close() + +@contextmanager +def capture(cap): + cap.start() + try: + yield cap + finally: + cap.stop() + +def check_output(output, patterns): + assert output, "Capture output failed!" + for pattern in patterns: + assert output.find(pattern) != -1, "Unexpected output:\n" + output + "\n--- pattern ---\n" + pattern diff --git a/tests/st/auto_monad/test_auto_monad.py b/tests/st/auto_monad/test_auto_monad.py new file mode 100644 index 0000000000..08d4b22fdc --- /dev/null +++ b/tests/st/auto_monad/test_auto_monad.py @@ -0,0 +1,1456 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +import os +import re +import time +import pytest +import numpy as np +import mindspore as ms +import mindspore.ops.operations as P +import mindspore.nn as nn +from mindspore.nn import Cell +from mindspore.nn import ReLU, BatchNorm2d, Conv2d, Dense, PReLU, ParameterUpdate +from mindspore.nn import Momentum, SoftmaxCrossEntropyWithLogits +from mindspore import context, Tensor +from mindspore.common.parameter import Parameter +from mindspore.common.initializer import initializer +from mindspore.ops.primitive import constexpr +from capture import Capture, capture, check_output + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + + +@pytest.fixture(name="pynative_save_graphs") +def _pynative_save_graphs(): + context.set_context(mode=context.PYNATIVE_MODE, save_graphs=True) + yield + context.set_context(mode=context.GRAPH_MODE, save_graphs=False) + clean_all_ir_files('./') + + +@pytest.fixture(name="with_save_graphs") +def _with_save_graphs(): + context.set_context(save_graphs=True) + yield + context.set_context(save_graphs=False) + clean_all_ir_files('./') + + +@pytest.mark.skip(reason="Ignore print detection") +def test_print(): + class Print(Cell): + def __init__(self): + super().__init__() + self.print = P.Print() + + def construct(self, x, y): + self.print("input_x:", x, "input_y:", y) + return x + + cap = Capture() + with capture(cap): + input_x = Tensor(3, dtype=ms.int32) + input_y = Tensor(4, dtype=ms.int32) + net = Print() + net(input_x, input_y) + time.sleep(0.1) + + patterns = {'input_x:\nTensor(shape=[], dtype=Int32, value=3)\n' + 'input_y:\nTensor(shape=[], dtype=Int32, value=4)'} + check_output(cap.output, patterns) + + +@pytest.mark.skip(reason="Ignore print detection") +def test_print_add(): + class Print_Add(Cell): + def __init__(self): + super().__init__() + self.print = P.Print() + self.add = P.Add() + + def construct(self, x, y): + x = self.add(x, y) + self.print("input_x:", x, "input_y:", y) + return x + + cap = Capture() + with capture(cap): + input_x = Tensor(3, dtype=ms.int32) + input_y = Tensor(4, dtype=ms.int32) + expect = Tensor(7, dtype=ms.int32) + net = Print_Add() + out = net(input_x, input_y) + time.sleep(0.1) + np.testing.assert_array_equal(out.asnumpy(), expect.asnumpy()) + + patterns = {'input_x:\nTensor(shape=[], dtype=Int32, value=7)\n' + 'input_y:\nTensor(shape=[], dtype=Int32, value=4)'} + check_output(cap.output, patterns) + + +@pytest.mark.skip(reason="Ignore print detection") +def test_print_assign(): + class Print_Assign(Cell): + def __init__(self): + super().__init__() + self.print = P.Print() + self.para = Parameter(Tensor(1, dtype=ms.int32), name='para') + + def construct(self, x): + self.print("before:", self.para) + self.para = x + self.print("after:", self.para) + return self.para + + cap = Capture() + with capture(cap): + input_x = Tensor(3, dtype=ms.int32) + expect = Tensor(3, dtype=ms.int32) + net = Print_Assign() + out = net(input_x) + time.sleep(0.1) + np.testing.assert_array_equal(out.asnumpy(), expect.asnumpy()) + + patterns = {'before:\nTensor(shape=[], dtype=Int32, value=1)', + 'after:\nTensor(shape=[], dtype=Int32, value=3)'} + check_output(cap.output, patterns) + + +@pytest.mark.skip(reason="Ignore print detection") +def test_print_assign_add(): + class Print_Assign_Add(Cell): + def __init__(self): + super().__init__() + self.print = P.Print() + self.add = P.Add() + self.para = Parameter(Tensor(1, dtype=ms.int32), name='para') + + def 
construct(self, x, y): + self.print("before:", self.para) + self.para = x + self.print("after:", self.para) + x = self.add(self.para, y) + return x + + cap = Capture() + with capture(cap): + input_x = Tensor(3, dtype=ms.int32) + input_y = Tensor(4, dtype=ms.int32) + expect = Tensor(7, dtype=ms.int32) + net = Print_Assign_Add() + out = net(input_x, input_y) + time.sleep(0.1) + np.testing.assert_array_equal(out.asnumpy(), expect.asnumpy()) + + patterns = {'before:\nTensor(shape=[], dtype=Int32, value=1)', + 'after:\nTensor(shape=[], dtype=Int32, value=3)'} + check_output(cap.output, patterns) + + +@pytest.mark.skip(reason="Ignore print detection") +def test_print_while(): + class Print_While(Cell): + def __init__(self): + super().__init__() + self.print = P.Print() + + def construct(self, x, y): + self.print("input_x before:", x, "input_y before:", y) + while x < y: + self.print("input_x after:", x, "input_y after:", y) + x = x + 1 + return x + + cap = Capture() + with capture(cap): + input_x = Tensor(1, dtype=ms.int32) + input_y = Tensor(4, dtype=ms.int32) + expect = Tensor(4, dtype=ms.int32) + net = Print_While() + out = net(input_x, input_y) + time.sleep(0.1) + np.testing.assert_array_equal(out.asnumpy(), expect.asnumpy()) + + patterns = {'input_x before:\nTensor(shape=[], dtype=Int32, value=1)\n' + 'input_y before:\nTensor(shape=[], dtype=Int32, value=4)', + 'input_x after:\nTensor(shape=[], dtype=Int32, value=1)\n' + 'input_y after:\nTensor(shape=[], dtype=Int32, value=4)', + 'input_x after:\nTensor(shape=[], dtype=Int32, value=2)\n' + 'input_y after:\nTensor(shape=[], dtype=Int32, value=4)', + 'input_x after:\nTensor(shape=[], dtype=Int32, value=3)\n' + 'input_y after:\nTensor(shape=[], dtype=Int32, value=4)'} + check_output(cap.output, patterns) + + +@pytest.mark.skip(reason="Ignore print detection") +def test_print_if(): + class Print_If(Cell): + def __init__(self): + super().__init__() + self.print = P.Print() + + def construct(self, x, y): + self.print("input_x before:", x, "input_y before:", y) + if x < y: + self.print("input_x after:", x, "input_y after:", y) + x = x + 1 + return x + + cap = Capture() + with capture(cap): + input_x = Tensor(3, dtype=ms.int32) + input_y = Tensor(4, dtype=ms.int32) + expect = Tensor(4, dtype=ms.int32) + net = Print_If() + out = net(input_x, input_y) + time.sleep(0.1) + np.testing.assert_array_equal(out.asnumpy(), expect.asnumpy()) + + patterns = {'input_x before:\nTensor(shape=[], dtype=Int32, value=3)\n' + 'input_y before:\nTensor(shape=[], dtype=Int32, value=4)', + 'input_x after:\nTensor(shape=[], dtype=Int32, value=3)\n' + 'input_y after:\nTensor(shape=[], dtype=Int32, value=4)'} + check_output(cap.output, patterns) + + +@pytest.mark.skip(reason="Ignore print detection") +def test_print_assign_while(): + class Print_Assign_While(Cell): + def __init__(self): + super().__init__() + self.print = P.Print() + self.para = Parameter(Tensor(0, dtype=ms.int32), name='para') + + def construct(self, x, y): + self.print("input_x before:", x, "input_y before:", + y, "para before:", self.para) + while x < y: + self.para = x + x = self.para + 1 + self.print("input_x after:", x, "input_y after:", + y, "para after:", self.para) + return x + + cap = Capture() + with capture(cap): + input_x = Tensor(1, dtype=ms.int32) + input_y = Tensor(4, dtype=ms.int32) + expect = Tensor(4, dtype=ms.int32) + net = Print_Assign_While() + out = net(input_x, input_y) + time.sleep(0.1) + np.testing.assert_array_equal(out.asnumpy(), expect.asnumpy()) + + patterns = { + 'input_x 
before:\nTensor(shape=[], dtype=Int32, value=1)\n' + 'input_y before:\nTensor(shape=[], dtype=Int32, value=4)\n' + 'para before:\nTensor(shape=[], dtype=Int32, value=0)', + 'input_x after:\nTensor(shape=[], dtype=Int32, value=2)\n' + 'input_y after:\nTensor(shape=[], dtype=Int32, value=4)\n' + 'para after:\nTensor(shape=[], dtype=Int32, value=1)', + 'input_x after:\nTensor(shape=[], dtype=Int32, value=3)\n' + 'input_y after:\nTensor(shape=[], dtype=Int32, value=4)\n' + 'para after:\nTensor(shape=[], dtype=Int32, value=2)', + 'input_x after:\nTensor(shape=[], dtype=Int32, value=4)\n' + 'input_y after:\nTensor(shape=[], dtype=Int32, value=4)\n' + 'para after:\nTensor(shape=[], dtype=Int32, value=3)'} + check_output(cap.output, patterns) + + +@pytest.mark.skip(reason="Ignore print detection") +def test_print_assign_if(): + class Print_Assign_If(Cell): + def __init__(self): + super().__init__() + self.print = P.Print() + self.para = Parameter(Tensor(1, dtype=ms.int32), name='para') + + def construct(self, x, y): + self.print("input_x before:", x, "input_y before:", + y, "para before:", self.para) + self.para = x + if x < y: + x = self.para + 1 + self.print("input_x after:", x, "input_y after:", + y, "para after:", self.para) + return x + + cap = Capture() + with capture(cap): + input_x = Tensor(3, dtype=ms.int32) + input_y = Tensor(4, dtype=ms.int32) + expect = Tensor(4, dtype=ms.int32) + net = Print_Assign_If() + out = net(input_x, input_y) + time.sleep(0.1) + np.testing.assert_array_equal(out.asnumpy(), expect.asnumpy()) + + patterns = { + 'input_x before:\nTensor(shape=[], dtype=Int32, value=3)\n' + 'input_y before:\nTensor(shape=[], dtype=Int32, value=4)\n' + 'para before:\nTensor(shape=[], dtype=Int32, value=1)', + 'input_x after:\nTensor(shape=[], dtype=Int32, value=4)\n' + 'input_y after:\nTensor(shape=[], dtype=Int32, value=4)\n' + 'para after:\nTensor(shape=[], dtype=Int32, value=3)'} + check_output(cap.output, patterns) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_assign(): + class Assign(Cell): + def __init__(self): + super().__init__() + self.para = Parameter(Tensor(1, dtype=ms.int32), name='para') + + def construct(self, value): + self.para = value + return self.para + + input_x = Tensor(3, dtype=ms.int32) + expect = Tensor(3, dtype=ms.int32) + net = Assign() + out = net(input_x) + np.testing.assert_array_equal(out.asnumpy(), expect.asnumpy()) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_assign_implicit(): + class Assign_Implicit(Cell): + def __init__(self): + super(Assign_Implicit, self).__init__() + self.b = Parameter(initializer( + 1, [5], ms.float32), name="global_step") + + def construct(self, w): + self.b = w + return self.b + + input_data = Tensor(np.ones([5]).astype(np.int32)) + net = Assign_Implicit() + out = net(input_data) + assert out.dtype == ms.float32 + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_assign_write_after_read(): + class Assign_WAR(Cell): + def __init__(self): + super(Assign_WAR, self).__init__() + self.assign = P.Assign() + self.sub = P.Sub() + self.add = P.Add() + self.para = Parameter(Tensor(1, dtype=ms.int32), name='para') + self.weight = Parameter(Tensor(5, dtype=ms.int32), name='weight') + + def construct(self, x, y): + # without auto_monad, execute 
order is wrong: Add - Assign - Sub - Assign + # expected execute order: Add - Assign - Assign - Sub + self.para = self.add(y, x) + self.assign(self.para, y) + return self.sub(self.para, self.weight) + + input_x = Tensor(3, dtype=ms.int32) + input_y = Tensor(4, dtype=ms.int32) + expect = Tensor(-1, dtype=ms.int32) + net = Assign_WAR() + out = net(input_x, input_y) + np.testing.assert_array_equal(out.asnumpy(), expect.asnumpy()) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_assign_read_after_write(): + class Assign_RAW(Cell): + def __init__(self): + super(Assign_RAW, self).__init__() + self.assign_add = P.AssignAdd() + self.greater = P.Greater() + self.add = P.Add() + self.para = Parameter(Tensor(1, dtype=ms.int32), name='para') + + def construct(self, x, y): + # without auto_monad, execute order is wrong: Add - Assign - Greater - AssignAdd + # expected execute order: AssignAdd - Add - Assign + self.greater(x, y) + self.assign_add(self.para, x) + self.para = self.add(x, y) + return self.para + + input_x = Tensor(3, dtype=ms.int32) + input_y = Tensor(4, dtype=ms.int32) + expect = Tensor(7, dtype=ms.int32) + net = Assign_RAW() + out = net(input_x, input_y) + np.testing.assert_array_equal(out.asnumpy(), expect.asnumpy()) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_assign_if(): + class Assign_If(Cell): + def __init__(self): + super().__init__() + self.para = Parameter(Tensor(1, dtype=ms.int32), name='para') + + def construct(self, x, y): + if x < y: + self.para = x + else: + self.para = y + return self.para + + input_x = Tensor(3, dtype=ms.int32) + input_y = Tensor(4, dtype=ms.int32) + expect = Tensor(3, dtype=ms.int32) + net = Assign_If() + out = net(input_x, input_y) + np.testing.assert_array_equal(out.asnumpy(), expect.asnumpy()) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_if(): + class If(Cell): + def __init__(self): + super().__init__() + self.add = P.Add() + self.sub = P.Sub() + + def construct(self, x, y): + if x > y: + x = self.sub(x, y) + else: + x = self.add(x, y) + return x + + input_x = Tensor(3, dtype=ms.int32) + input_y = Tensor(4, dtype=ms.int32) + expect = Tensor(7, dtype=ms.int32) + net = If() + out = net(input_x, input_y) + np.testing.assert_array_equal(out.asnumpy(), expect.asnumpy()) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_while(): + class While(Cell): + def construct(self, x, y): + y = y + 4 + while x < y: + x = x + 1 + x = x + 3 + return x + + input_x = Tensor(2, dtype=ms.int32) + input_y = Tensor(14, dtype=ms.int32) + expect = Tensor(21, dtype=ms.int32) + net = While() + out = net(input_x, input_y) + np.testing.assert_array_equal(out.asnumpy(), expect.asnumpy()) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_assign_while(): + class Assign_While(Cell): + def __init__(self): + super().__init__() + self.para = Parameter(Tensor(1, dtype=ms.int32), name='para') + + def construct(self, x, y): + y = y + 4 + while x < y: + x = x + 1 + self.para = x + self.para = x - 1 + return self.para + + input_x = Tensor(2, dtype=ms.int32) + input_y = Tensor(14, dtype=ms.int32) + expect = 
Tensor(17, dtype=ms.int32) + net = Assign_While() + out = net(input_x, input_y) + np.testing.assert_array_equal(out.asnumpy(), expect.asnumpy()) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_for(): + class For(Cell): + def construct(self, x, y): + y = x + y + for _ in range(20): + y = y + 1 + return y + + input_x = Tensor(2, dtype=ms.int32) + input_y = Tensor(4, dtype=ms.int32) + expect = Tensor(26, dtype=ms.int32) + net = For() + out = net(input_x, input_y) + np.testing.assert_array_equal(out.asnumpy(), expect.asnumpy()) + + +@pytest.mark.skip(reason="Ignore print detection") +def test_print_for(): + class Print_For(Cell): + def __init__(self): + super().__init__() + self.print = P.Print() + + def construct(self, x, y): + y = x + y + self.print("input_x before:", x, "input_y before:", y) + for _ in range(3): + y = y + 1 + self.print("input_x after:", x, "input_y after:", y) + return y + + cap = Capture() + with capture(cap): + input_x = Tensor(2, dtype=ms.int32) + input_y = Tensor(4, dtype=ms.int32) + expect = Tensor(9, dtype=ms.int32) + net = Print_For() + out = net(input_x, input_y) + time.sleep(0.1) + np.testing.assert_array_equal(out.asnumpy(), expect.asnumpy()) + + patterns = { + 'input_x before:\nTensor(shape=[], dtype=Int32, value=2)\n' + 'input_y before:\nTensor(shape=[], dtype=Int32, value=6)', + 'input_x after:\nTensor(shape=[], dtype=Int32, value=2)\n' + 'input_y after:\nTensor(shape=[], dtype=Int32, value=7)', + 'input_x after:\nTensor(shape=[], dtype=Int32, value=2)\n' + 'input_y after:\nTensor(shape=[], dtype=Int32, value=8)', + 'input_x after:\nTensor(shape=[], dtype=Int32, value=2)\n' + 'input_y after:\nTensor(shape=[], dtype=Int32, value=9)'} + check_output(cap.output, patterns) + + +@pytest.mark.skip(reason="Ignore print detection") +def test_print_assign_for(): + class Print_Assign_For(Cell): + def __init__(self): + super().__init__() + self.print = P.Print() + self.para = Parameter(Tensor(1, dtype=ms.int32), name='para') + + def construct(self, x, y): + y = x + y + self.print("input_x before:", x, "input_y before:", + y, "para before:", self.para) + for _ in range(3): + y = y + 1 + self.para = x + y + self.print("input_x after:", x, "input_y after:", + y, "para after:", self.para) + return y + + cap = Capture() + with capture(cap): + input_x = Tensor(2, dtype=ms.int32) + input_y = Tensor(4, dtype=ms.int32) + expect = Tensor(9, dtype=ms.int32) + net = Print_Assign_For() + out = net(input_x, input_y) + time.sleep(0.1) + np.testing.assert_array_equal(out.asnumpy(), expect.asnumpy()) + + patterns = { + 'input_x before:\nTensor(shape=[], dtype=Int32, value=2)\n' + 'input_y before:\nTensor(shape=[], dtype=Int32, value=6)\n' + 'para before:\nTensor(shape=[], dtype=Int32, value=1)', + 'input_x after:\nTensor(shape=[], dtype=Int32, value=2)\n' + 'input_y after:\nTensor(shape=[], dtype=Int32, value=7)\n' + 'para after:\nTensor(shape=[], dtype=Int32, value=9)', + 'input_x after:\nTensor(shape=[], dtype=Int32, value=2)\n' + 'input_y after:\nTensor(shape=[], dtype=Int32, value=8)\n' + 'para after:\nTensor(shape=[], dtype=Int32, value=10)', + 'input_x after:\nTensor(shape=[], dtype=Int32, value=2)\n' + 'input_y after:\nTensor(shape=[], dtype=Int32, value=9)\n' + 'para after:\nTensor(shape=[], dtype=Int32, value=11)'} + check_output(cap.output, patterns) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training 
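+# The `for` loop in the test below is unrolled at graph compile time, so each
+# iteration's write to self.para is ordered by the auto-monad chain; with x=2
+# and y=3+4=7, the five additions give 2 + 5*7 = 37, the expected result.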
+@pytest.mark.env_onecard +def test_assign_for(): + class Assign_For(Cell): + def __init__(self): + super().__init__() + self.para = Parameter(Tensor(1, dtype=ms.int32), name='para') + + def construct(self, x, y): + y = y + 4 + for _ in range(5): + x = x + y + self.para = x + return self.para + + input_x = Tensor(2, dtype=ms.int32) + input_y = Tensor(3, dtype=ms.int32) + expect = Tensor(37, dtype=ms.int32) + net = Assign_For() + out = net(input_x, input_y) + np.testing.assert_array_equal(out.asnumpy(), expect.asnumpy()) + + +@constexpr +def _check_shape(shape): + if len(shape) != 1: + raise ValueError(f"Invalid shape {shape}") + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_constexpr_check(): + class ConstexprCheck(Cell): + def __init__(self): + super(ConstexprCheck, self).__init__() + self.shape = P.Shape() + + def construct(self, x, y): + s = self.shape(x) + _check_shape(s) + x = x + y + return x + + x = Tensor([2], dtype=ms.int32) + y = Tensor([3], dtype=ms.int32) + expect = Tensor(5, dtype=ms.int32) + net = ConstexprCheck() + # Input with valid shape. + out = net(x, y) + np.testing.assert_array_equal(out.asnumpy(), expect.asnumpy()) + # Input with wrong shape, exception is expected. + with pytest.raises(ValueError): + wrong_x = Tensor(np.ones((2, 2)), dtype=ms.int32) + out = net(wrong_x, y) + print(out) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_if_lambda(): + class If_Lambda(Cell): + def __init__(self): + super().__init__() + self.para = Parameter(Tensor(1, dtype=ms.int32), name='para') + + def construct(self, x, y): + out = x + if x < y: + x2 = (lambda a: a + a) + out = x2(self.para) + out = out + y + return out + + input_x = Tensor(2, dtype=ms.int32) + input_y = Tensor(3, dtype=ms.int32) + expect = Tensor(5, dtype=ms.int32) + net = If_Lambda() + out = net(input_x, input_y) + np.testing.assert_array_equal(out.asnumpy(), expect.asnumpy()) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_multi_assign(): + class Multi_Assign(Cell): + def __init__(self): + super().__init__() + self.assign = P.Assign() + self.para1 = Parameter(Tensor(1, dtype=ms.int32), name='para1') + self.para2 = Parameter(Tensor(2, dtype=ms.int32), name='para2') + self.para3 = Parameter(Tensor(3, dtype=ms.int32), name='para3') + + def construct(self, x, y, z): + a = self.assign(self.para1, x) + a = self.assign(self.para2, y) + a = self.assign(self.para3, z) + return self.para1 + self.para2 + a + + x = Tensor(4, dtype=ms.int32) + y = Tensor(5, dtype=ms.int32) + z = Tensor(6, dtype=ms.int32) + expect = Tensor(15, dtype=ms.int32) + net = Multi_Assign() + out = net(x, y, z) + np.testing.assert_array_equal(out.asnumpy(), expect.asnumpy()) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_multi_assign_addn(): + class Multi_Assign_Addn(Cell): + def __init__(self): + super().__init__() + self.addn = P.AddN() + self.assign = P.Assign() + self.para1 = Parameter(Tensor(1.0, dtype=ms.float32), name='para1') + self.para2 = Parameter(Tensor(3.0, dtype=ms.float32), name='para2') + + def construct(self, inputs): + self.assign(self.para1, inputs) + out = self.addn((inputs, self.para1, self.para2)) + self.assign(self.para2, inputs) + out = 
self.addn((out, self.para1, self.para2)) + return out + + x = Tensor(9.0, dtype=ms.float32) + expect = Tensor(39.0, dtype=ms.float32) + net = Multi_Assign_Addn() + out = net(x) + np.testing.assert_almost_equal(out.asnumpy(), expect.asnumpy()) + + +@pytest.mark.skip(reason="Ignore print detection") +def test_multi_assign_print(): + class Multi_Assign_Print(Cell): + def __init__(self): + super().__init__() + self.pow = P.Pow() + self.print = P.Print() + self.assign = P.Assign() + self.exponent = Tensor([2.0], ms.float32) + self.para1 = Parameter(Tensor(1.0, dtype=ms.float32), name='para1') + self.para2 = Parameter(Tensor(3.0, dtype=ms.float32), name='para2') + + def construct(self, inputs): + self.assign(self.para1, inputs) + self.assign(self.para2, self.pow(inputs, self.exponent)) + self.print(inputs) + self.print(self.para1) + self.print(self.para2) + return inputs + + x = Tensor(9.0, dtype=ms.float32) + expect = Tensor(9.0, dtype=ms.float32) + expect_para1 = Tensor(9.0, dtype=ms.float32) + expect_para2 = Tensor(81.00001, dtype=ms.float32) + net = Multi_Assign_Print() + out = net(x) + np.testing.assert_almost_equal(out.asnumpy(), expect.asnumpy()) + np.testing.assert_almost_equal( + net.para1.data.asnumpy(), expect_para1.asnumpy()) + np.testing.assert_almost_equal( + net.para2.data.asnumpy(), expect_para2.asnumpy()) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_matmul_assign_biasadd(): + class Matmul_Assign_Biasadd(Cell): + def __init__(self): + super().__init__() + inputs = np.array([[1, 1], [1, 1]]) + self.parameter1 = Parameter( + Tensor(inputs, ms.float32), name="parameter1") + biasadd = np.array([0, -1]) + self.parameter2 = Parameter( + Tensor(biasadd, ms.float32), name="biasadd") + self.assign = P.Assign() + self.matmul = P.MatMul() + self.biasadd = P.BiasAdd() + + def construct(self, x): + self.assign(self.parameter1, x) + x = self.matmul(x, self.parameter1) + self.assign(self.parameter1, x) + x = self.biasadd(x, self.parameter2) + return x + + net = Matmul_Assign_Biasadd() + inputs = np.array([[1, 2], [3, 4]]) + out1 = net(Tensor(inputs, ms.float32)) + net = Matmul_Assign_Biasadd() + try: + context.set_context(mode=context.PYNATIVE_MODE) + out2 = net(Tensor(inputs, ms.float32)) + np.testing.assert_almost_equal(out1.asnumpy(), out2.asnumpy()) + finally: + context.set_context(mode=context.GRAPH_MODE) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_assign_while_if(): + class Assign_While_If(Cell): + def __init__(self): + super().__init__() + self.mul = P.Mul() + self.addn = P.AddN() + self.assign = P.Assign() + self.assign_sub = P.AssignSub() + self.para = Parameter(Tensor(1.0, dtype=ms.float32), name='para') + + def construct(self, x, y, z, w): + self.assign(self.para, x) + if self.para > y: + self.assign(self.para, y) + x = self.mul(x, x) + while self.para > z: + x = self.addn((x, self.para)) + self.assign_sub(self.para, w) + return x + + x = Tensor(99.0, dtype=ms.float32) + y = Tensor(44.0, dtype=ms.float32) + z = Tensor(11.0, dtype=ms.float32) + w = Tensor(1.0, dtype=ms.float32) + expect = Tensor(10725.0, dtype=ms.float32) + net = Assign_While_If() + out = net(x, y, z, w) + np.testing.assert_almost_equal(out.asnumpy(), expect.asnumpy()) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def 
test_isolate_call(): + class Net(Cell): + def __init__(self): + super().__init__() + self.para1 = Parameter(Tensor(1, dtype=ms.int32), name='para1') + self.para2 = Parameter(Tensor(2, dtype=ms.int32), name='para2') + + def construct(self, x, y): + self.setpara(x, y) + return self.para1 + self.para2 + + def setpara(self, x, y): + self.para1 = x + self.setpara2(y) + return x + + def setpara2(self, y): + self.para2 = y + return y + + x = Tensor(4, dtype=ms.int32) + y = Tensor(5, dtype=ms.int32) + expect = Tensor(9, dtype=ms.int32) + net = Net() + out = net(x, y) + np.testing.assert_array_equal(out.asnumpy(), expect.asnumpy()) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_assign_return_true(): + class Net(Cell): + def __init__(self): + super(Net, self).__init__() + self.para = Parameter(Tensor(1, dtype=ms.int32), name='para') + + def construct(self, x, y): + if self.mycheck(x, y): + out = x + y + else: + out = x - y + out = self.para + out + return out + + def mycheck(self, x, y): + self.setpara(x, y) + return True + + def setpara(self, x, y): + self.para = x + y + return True + + x = Tensor(2, dtype=ms.int32) + y = Tensor(3, dtype=ms.int32) + expect = Tensor(10, dtype=ms.int32) + net = Net() + out = net(x, y) + np.testing.assert_array_equal(out.asnumpy(), expect.asnumpy()) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_unpack_call(): + class SetPara(Cell): + def __init__(self, para): + super(SetPara, self).__init__() + self.para = para + + def construct(self, x, y): + self.para = x + y + return True + + class MyNet(Cell): + def __init__(self): + super(MyNet, self).__init__() + self.para = Parameter(Tensor(1, dtype=ms.int32), name='para') + self.set_para = SetPara(self.para) + + def construct(self, *inputs): + self.call_func(self.set_para, *inputs) + out = self.para + 1 + return out + + def call_func(self, func, *inputs): + func(*inputs) + return True + + x = Tensor(2, dtype=ms.int32) + y = Tensor(3, dtype=ms.int32) + expect = Tensor(6, dtype=ms.int32) + net = MyNet() + out = net(x, y) + np.testing.assert_array_equal(out.asnumpy(), expect.asnumpy()) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_tuple_of_tuple(): + class SetPara(Cell): + def __init__(self, para): + super(SetPara, self).__init__() + self.para = para + + def construct(self, x, y): + self.para = x + y + return True + + class MyNet(Cell): + def __init__(self): + super(MyNet, self).__init__() + self.para = Parameter(Tensor(1, dtype=ms.int32), name='para') + self.set_para = SetPara(self.para) + + def construct(self, x, y): + t1 = (self.set_para, x) + t2 = (t1, y) + t2[0][0](t2[1], t1[1]) + out = self.para + 1 + return out + + def call_func(self, func, *inputs): + func(*inputs) + return True + + x = Tensor(2, dtype=ms.int32) + y = Tensor(3, dtype=ms.int32) + expect = Tensor(6, dtype=ms.int32) + net = MyNet() + out = net(x, y) + np.testing.assert_array_equal(out.asnumpy(), expect.asnumpy()) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_write_read_write(): + class MyNet(Cell): + def __init__(self): + super(MyNet, self).__init__() + self.para1 = Parameter(Tensor(1, dtype=ms.int32), name='para1') + self.para2 = Parameter(Tensor(2, 
dtype=ms.int32), name='para2') + + def construct(self, x, y, x1, y1): + self.para1 = x + self.para2 = y + a = self.para1 + self.para2 + self.para1 = x1 + self.para2 = y1 + return a + self.para1 + self.para2 + + x = Tensor(3, dtype=ms.int32) + y = Tensor(4, dtype=ms.int32) + x1 = Tensor(5, dtype=ms.int32) + y1 = Tensor(6, dtype=ms.int32) + expect = Tensor(18, dtype=ms.int32) + net = MyNet() + out = net(x, y, x1, y1) + np.testing.assert_array_equal(out.asnumpy(), expect.asnumpy()) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_variable_from_outer_graph(): + class MyNet(Cell): + def __init__(self): + super(MyNet, self).__init__() + self.cond = False + self.add = P.Add() + self.para = Parameter(Tensor(1, dtype=ms.int32), name='para') + + def construct(self, x, y): + b = self.para + x + a = self.para + b + if self.cond: + a = self.add(a, x) + else: + a = self.add(a, y) + return a + b + + x = Tensor(2, dtype=ms.int32) + y = Tensor(3, dtype=ms.int32) + expect = Tensor(10, dtype=ms.int32) + net = MyNet() + out = net(x, y) + np.testing.assert_array_equal(out.asnumpy(), expect.asnumpy()) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_ctrl_while_by_while_and_if_in_first_while(): + class Net(Cell): + def __init__(self): + super().__init__() + self.relu = P.ReLU() + self.sigmoid = P.Sigmoid() + self.tanh = P.Tanh() + self.add = P.Add() + a = np.full((1,), 5, dtype=np.float32) + self.a = Parameter(Tensor(a), name="a") + b = np.full((1,), 4, dtype=np.float32) + self.b = Parameter(Tensor(b), name="b") + c = np.full((1,), 7, dtype=np.float32) + self.c = Parameter(Tensor(c), name="c") + + def construct(self, x): + out = x + while self.a < 7: + if self.a < self.c: + out = self.relu(x) + self.a += 1 + while self.c > 5: + out = self.add(out, out) + self.c -= 1 + return out + + context.set_context(mode=context.GRAPH_MODE) + input_np_a = np.random.randn(2, 3, 4, 5).astype(np.float32) + input_me_a = Tensor(input_np_a) + net = Net() + net(input_me_a) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_ctrl_if_by_while_and_while_in_first_if(): + class Net(Cell): + def __init__(self): + super().__init__() + self.relu = P.ReLU() + self.sigmoid = P.Sigmoid() + self.tanh = P.Tanh() + self.add = P.Add() + a = np.full((1,), 5, dtype=np.float32) + self.a = Parameter(Tensor(a), name="a") + b = np.full((1,), 4, dtype=np.float32) + self.b = Parameter(Tensor(b), name="b") + c = np.full((1,), 7, dtype=np.float32) + self.c = Parameter(Tensor(c), name="c") + + def construct(self, x): + out = x + if self.a < self.c: + out = self.relu(x) + while self.a < 7: + self.a += 1 + + while self.c > 5: + out = self.add(out, out) + self.c -= 1 + return out + + context.set_context(mode=context.GRAPH_MODE) + input_np_a = np.random.randn(2, 3, 4, 5).astype(np.float32) + input_me_a = Tensor(input_np_a) + net = Net() + net(input_me_a) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_ctrl_while_by_while_and_while_in_first_while(): + class Net(Cell): + def __init__(self): + super().__init__() + self.relu = P.ReLU() + self.sigmoid = P.Sigmoid() + self.tanh = P.Tanh() + self.add = P.Add() + a = np.full((1,), 5, dtype=np.float32) + self.a = Parameter(Tensor(a), 
name="a") + b = np.full((1,), 4, dtype=np.float32) + self.b = Parameter(Tensor(b), name="b") + c = np.full((1,), 7, dtype=np.float32) + self.c = Parameter(Tensor(c), name="c") + + def construct(self, x): + out = x + while self.a < self.c: + out = self.relu(x) + while self.b > 1: + self.b -= 1 + self.a += 1 + + while self.c > 5: + out = self.add(out, out) + self.c -= 1 + return out + + context.set_context(mode=context.GRAPH_MODE) + input_np_a = np.random.randn(2, 3, 4, 5).astype(np.float32) + input_me_a = Tensor(input_np_a) + net = Net() + net(input_me_a) + + +def clear_json_info(): + os.system("rm -rf ./kernel_meta/*.json") + os.system("rm -rf ./kernel_meta/*.info") + + +def find_json_info(file): + result = os.system("ls -al ./kernel_meta/%s" % (file)) + return result + + +class MultiOutReluBywaySqrt(Cell): + def __init__(self): + super().__init__() + self.relu = nn.ReLU() + self.sqrt = P.Sqrt() + + def construct(self, x): + x = self.relu(x) + x = self.relu(x) + x1 = self.relu(x) + x = self.relu(x1) + y = self.sqrt(x1) + return x, y + + +class MultiOutReluSqrtBywaySqrt(Cell): + def __init__(self): + super().__init__() + self.relu = nn.ReLU() + self.sqrt = P.Sqrt() + self.sin = P.Sin() + + def construct(self, x): + x = self.relu(x) + x = self.sqrt(x) + x1 = self.relu(x) + x = self.sin(x1) + y = self.sqrt(x1) + return x, y + + +def clean_all_ir_files(folder_path): + if os.path.exists(folder_path): + for file_name in os.listdir(folder_path): + if file_name.endswith('.ir') or file_name.endswith('.dot') or \ + file_name.endswith('.dat') or file_name.endswith('.pb') or \ + file_name.startswith('trace_code_graph'): + os.remove(os.path.join(folder_path, file_name)) + + +def find_newest_validateir_file(folder_path): + ckpt_files = map(lambda f: os.path.join(folder_path, f), + filter(lambda f: re.match(r'\d+_validate_\d+.ir', f), + os.listdir(folder_path))) + return max(ckpt_files, key=os.path.getctime) + + +def read_file(): + filename = find_newest_validateir_file('./') + with open((os.path.join(filename)), 'r') as f: + content = f.read() + return content + + +# Net contain Prelu,BN,Conv,Dense which have weight value +class NetRrelu(Cell): + def __init__(self, in_channel, out_channel): + super().__init__() + self.relu = PReLU(channel=in_channel, w=0.25) + self.bn = BatchNorm2d(num_features=in_channel) + self.conv = Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=2, stride=1, has_bias=False, + weight_init='ones', pad_mode='same') + self.mean = P.ReduceMean(keep_dims=False) + self.fc = Dense(in_channels=out_channel, out_channels=out_channel, + weight_init='ones', bias_init='zeros', has_bias=True) + + def construct(self, x): + x = self.relu(x) + x = self.bn(x) + x = self.conv(x) + x = self.mean(x, (2, 3)) + x = self.fc(x) + return x + + +def check_keep_batchnorm_fp32_false(kwargs, level): + if ms.context.get_context("device_target") == "GPU": + if level == "O2": + if "keep_batchnorm_fp32" in kwargs.keys() and (not kwargs["keep_batchnorm_fp32"]): + if "cast_model_type" not in kwargs.keys() or kwargs["cast_model_type"] == ms.float16: + return True + else: + if "cast_model_type" in kwargs.keys() and kwargs["cast_model_type"] == ms.float16: + if "keep_batchnorm_fp32" not in kwargs.keys() or (not kwargs["keep_batchnorm_fp32"]): + return True + return False + + +def use_build_train_network_check_cast_num(network, level, inputs, label, cast_num, loss_flag=True, **kwargs): + diff_cast = 0 + if check_keep_batchnorm_fp32_false(kwargs, level): + diff_cast += 8 + opt = 
Momentum(learning_rate=0.0001, momentum=0.009, + params=network.trainable_params()) + loss = None + if loss_flag: + loss = SoftmaxCrossEntropyWithLogits(sparse=False, reduction='mean') + + train_network = ms.amp.build_train_network( + network, opt, loss, level=level, **kwargs) + out_me = train_network(inputs, label) + if context.get_context("mode") == 0: + content = read_file() + castnum = re.findall('Cast', content) + assert len(castnum) == max(cast_num - diff_cast, 0) + return out_me + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_auto_mixed_precision_train_prelunet(with_save_graphs): + net2 = NetRrelu(3, 12) + input32 = Tensor(np.ones([1, 3, 2, 2]).astype(np.float32)) + label32 = Tensor(np.zeros([1, 12]).astype(np.float32)) + use_build_train_network_check_cast_num(net2, "O2", input32, label32, 16) + + +class AssignNet(Cell): + def __init__(self): + super().__init__() + #self._save_graphs(save_graph_flag=True, save_graph_path=".") + self.relu = ReLU() + self.mean = P.ReduceMean(keep_dims=False) + self.assign_sub = P.AssignSub() + self.input_data = Parameter(initializer( + 1, [1, 3, 2, 2], ms.float32), name='value') + + def construct(self, x): + x = self.assign_sub(self.input_data, x) + x = self.relu(x) + x = self.mean(x, (2, 3)) + return x + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_auto_mixed_precision_train_021(pynative_save_graphs): + net = AssignNet() + input32 = Tensor(np.ones([1, 3, 2, 2]).astype(np.float32)) + label32 = Tensor(np.zeros([1, 3]).astype(np.float32)) + use_build_train_network_check_cast_num(net, "O0", input32, label32, 0) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_auto_mixed_precision_train_022(pynative_save_graphs): + net = AssignNet() + input32 = Tensor(np.ones([1, 3, 2, 2]).astype(np.float32)) + label32 = Tensor(np.zeros([1, 3]).astype(np.float32)) + use_build_train_network_check_cast_num(net, "O2", input32, label32, 2) + + +class MixControlNet(Cell): + def __init__(self, in_channel, x): + super().__init__() + #self._save_graphs(save_graph_flag=True, save_graph_path=".") + self.biasadd = P.BiasAdd() + self.equal = P.Equal() + self.addn = P.AddN() + self.conv = Conv2d(in_channels=in_channel, out_channels=in_channel, + kernel_size=1, stride=1, has_bias=False, + weight_init='ones', pad_mode='same') + self.bn = BatchNorm2d(num_features=in_channel) + self.assignadd = P.AssignAdd() + self.assign = P.Assign() + self.relu = ReLU() + self.mean = P.ReduceMean(keep_dims=False) + self.bias = Parameter( + Tensor(np.random.randint(2, size=(3,)).astype((np.float32))), + name="bias") + self.bias2 = Parameter(Tensor(np.ones([3]).astype(np.float32)), + name="bias2") + self.parameterupdate = ParameterUpdate(self.bias) + self.value = Tensor(np.random.randn(*(3,)), ms.float32) + self.x = x + + def construct(self, input_x): + x = self.x + z = self.x + out = self.biasadd(input_x, self.bias) + while x < 20: + update = self.parameterupdate(self.bias2) + out = self.biasadd(out, update) + if x < 10: + out = self.addn((input_x, out)) + while z < 20: + out = self.conv(out) + z = z + 1 + if x < 20: + out = self.biasadd(out, self.bias) + if x % 2 == 0: + out = self.biasadd(out, self.bias) + self.assignadd(self.bias, self.value) + out = self.bn(out) + else: + out = self.conv(out) + x = x + 1 + 
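+        # After the while loop finishes: double the accumulated output,
+        # then average it over the spatial dims (2, 3).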
out = self.addn((out, out)) + out = self.mean(out, (2, 3)) + return out + + +def use_build_train_network_controlflow_check_cast_num(network, level, input_x, + label, cast_num, + sparse=False, + loss_flag=True, + **kwargs): + opt = Momentum(learning_rate=0.0001, momentum=0.009, + params=network.trainable_params()) + loss = None + if loss_flag: + loss = SoftmaxCrossEntropyWithLogits(sparse=sparse, reduction='mean') + + train_network = ms.amp.build_train_network(network, opt, loss, level=level, + **kwargs) + out_me = train_network(input_x, label) + if context.get_context("mode") == 0: + content = read_file() + castnum = re.findall('Cast', content) + assert len(castnum) == cast_num + return out_me + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_auto_mixed_precision_controlflow_auto_1(pynative_save_graphs): + net = MixControlNet(3, 5) + input_x = Tensor( + np.random.randint(2, size=(1, 3, 2, 2)).astype((np.float32))) + label = Tensor(np.zeros([1, 3]).astype(np.float32)) + if ms.context.get_context("device_target") == "Ascend": + cast_num = 77 + if ms.context.get_context("device_target") == "GPU": + cast_num = 73 + use_build_train_network_controlflow_check_cast_num(net, "auto", input_x, + label, cast_num) + + +# op_cast should be located in order_list after abstract_specialize. +# Besides Ascend, it can work on CPU. +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_if_cast(): + class Net(nn.Cell): + def __init__(self, cond1): + super().__init__() + self.cond1 = cond1 + self.op_cast = P.Cast() + self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z') + + def construct(self, beta1, beta2): + z_local = self.op_cast(self.z, ms.float16) + self.z = beta2 + if self.cond1: + out = z_local + beta1 + else: + out = z_local - beta1 + + return out + + context.set_context(save_graphs=False) + net = Net(True) + beta1 = Tensor(np.array([2]).astype(np.float32)) + beta2 = Tensor(np.array([10]).astype(np.float32)) + r1 = net(beta1, beta2) + expect = Tensor(np.array([3]).astype(np.float32)) + np.testing.assert_array_equal(r1.asnumpy(), expect.asnumpy()) diff --git a/tests/st/auto_monad/test_auto_monad_gpu.py b/tests/st/auto_monad/test_auto_monad_gpu.py new file mode 100644 index 0000000000..270d5f13bd --- /dev/null +++ b/tests/st/auto_monad/test_auto_monad_gpu.py @@ -0,0 +1,576 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
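For reference, the loss-scale cells changed earlier in this patch all apply the same recipe: the implicit has_effect flag is replaced by explicit depend edges, forcing clear-status to run after the loss and before the gradient computation, and get-status to run only after the gradients. A condensed sketch of that ordering (illustrative; a multiply stands in for the real GradOperation call):

    import mindspore.nn as nn
    import mindspore.ops as ops

    class OverflowCheckSketch(nn.Cell):
        # Mirrors the depend-ordered NPU float-status pattern used above.
        def __init__(self):
            super().__init__()
            self.alloc_status = ops.NPUAllocFloatStatus()
            self.clear_status = ops.NPUClearFloatStatus()
            self.get_status = ops.NPUGetFloatStatus()
            self.reduce_sum = ops.ReduceSum(keep_dims=False)

        def construct(self, loss, scaling_sens):
            init = self.alloc_status()
            init = ops.depend(init, loss)            # status ops run after the forward pass
            clear_status = self.clear_status(init)
            scaling_sens = ops.depend(scaling_sens, clear_status)  # grads start from a clean status
            grads = scaling_sens * loss              # placeholder for the GradOperation call
            init = ops.depend(init, grads)           # read the status only after grads
            get_status = self.get_status(init)
            init = ops.depend(init, get_status)
            flag_sum = self.reduce_sum(init, (0,))
            return grads, flag_sum > 0               # >0 means an overflow bit was set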
+# ============================================================================== +import os +import re +import subprocess +import pytest +import numpy as np +import mindspore as ms +import mindspore.ops.operations as P +from mindspore.nn import Cell +from mindspore.nn import ReLU, BatchNorm2d, Conv2d, ParameterUpdate +from mindspore.nn import Momentum +from mindspore.nn import SoftmaxCrossEntropyWithLogits +from mindspore import amp +from mindspore import context, Tensor +from mindspore.common import ParameterTuple +from mindspore.common.parameter import Parameter +from mindspore.ops.composite import GradOperation + +context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + + +class _Grad(Cell): + def __init__(self, grad, network, wrt_params=False, real_inputs_count=None): + super().__init__() + self.network = network + self.grad = grad + self.sens_param = self.grad.sens_param + self.wrt_params = wrt_params + self.real_inputs_count = real_inputs_count + if self.wrt_params: + self.params = ParameterTuple(self.network.trainable_params()) + + def construct(self, *inputs): + if self.real_inputs_count is None or self.sens_param is False: + if self.wrt_params: + return self.grad(self.network, self.params)(*inputs) + return self.grad(self.network)(*inputs) + + real_inputs = inputs[:self.real_inputs_count] + sense_param_inputs = inputs[self.real_inputs_count:] + if self.wrt_params: + return self.grad(self.network, self.params)(*real_inputs, sense_param_inputs) + return self.grad(self.network)(*real_inputs, sense_param_inputs) + + +class GradOfAllInputs(_Grad): + ''' + get grads of all inputs + ''' + + def __init__(self, network, sens_param=True, real_inputs_count=None): + super().__init__(grad=GradOperation(get_all=True, sens_param=sens_param), + network=network, real_inputs_count=real_inputs_count) + + +class GradOfAllInputsAndParams(_Grad): + ''' + get grads of all inputs and params + ''' + + def __init__(self, network, sens_param=True, real_inputs_count=None): + super().__init__(grad=GradOperation(get_all=True, get_by_list=True, sens_param=sens_param), + network=network, wrt_params=True, real_inputs_count=real_inputs_count) + + +def _count_unequal_element(data_expected, data_me, rtol, atol): + assert data_expected.shape == data_me.shape + total_count = len(data_expected.flatten()) + error = np.abs(data_expected - data_me) + greater = np.greater(error, atol + np.abs(data_me)*rtol) + loss_count = np.count_nonzero(greater) + assert (loss_count/total_count) < rtol, \ + "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}".\ + format(data_expected[greater], data_me[greater], error[greater]) + + +def allclose_nparray(data_expected, data_me, rtol, atol, equal_nan=True): + if np.any(np.isnan(data_expected)): + assert np.allclose(data_expected, data_me, rtol, + atol, equal_nan=equal_nan) + elif not np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan): + _count_unequal_element(data_expected, data_me, rtol, atol) + else: + assert True + + +def clear_files(): + os.system("rm verbose_ir_files/*") + + +def find_files(file, para): + output = subprocess.check_output( + ["grep '%s' verbose_ir_files/%s | wc -l" % (para, file)], + shell=True) + out = str(output, 'utf-8').strip() + return out + + +class SideEffectCastAll(Cell): + def __init__(self): + super().__init__() + self.cast = P.Cast() + self.dtype = ms.float16 + np.random.seed(5) + inputs1 = np.random.randn(5, 5) + inputs2 = np.random.randn(5, 5) + self.parameter_a = Parameter(Tensor(inputs1, ms.float32), name="a") + 
self.parameter_b = Parameter(Tensor(inputs2, ms.float32), name="b") + self.assign = P.Assign() + + def construct(self, x, y): + self.assign(self.parameter_a, x) + self.assign(self.parameter_b, y) + out_a = self.cast(self.parameter_a, self.dtype) + out_b = self.cast(self.parameter_b, self.dtype) + return out_a, out_b + + +# @pytest.mark.level0 +# @pytest.mark.platform_x86_gpu_training +# @pytest.mark.env_onecard +@pytest.mark.skip(reason="not stable") +def test_side_effect_castall(): + clear_files() + context.set_context(mode=context.GRAPH_MODE, save_graphs=True) + net = SideEffectCastAll() + inputs1 = np.random.randn(5, 5) + inputs2 = np.random.randn(5, 5) + net(Tensor(inputs1, ms.float32), Tensor(inputs2, ms.float32)) + result = find_files('hwopt*cast_all*.ir', 'CastAll') + assert result == '2' + + +class SideEffectControlFlowAssignDependWhileNet(Cell): + def __init__(self): + super().__init__() + self.parameter1 = Parameter( + Tensor([199.0], ms.float32), name="parameter1") + self.assign = P.Assign() + self.assignadd = P.AssignAdd() + self.addn = P.AddN() + + def construct(self, x, y, z): + self.assign(self.parameter1, x) + while self.parameter1 < y: + x = self.addn((x, x)) + self.assignadd(self.parameter1, z) + return x + + def grad_mindspore_impl(self, params1, params2, params3, grad_ys): + grad_net = GradOfAllInputsAndParams(self) + grad_net.set_train() + grad_out = grad_net(params1, params2, params3, grad_ys) + return grad_out + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_side_effect_control_flow_assign_depend_while_net(): + net = SideEffectControlFlowAssignDependWhileNet() + context.set_context(mode=context.GRAPH_MODE) + out1 = net(Tensor([9.0], ms.float32), Tensor( + [99.0], ms.float32), Tensor([1.0], ms.float32)) + net = SideEffectControlFlowAssignDependWhileNet() + context.set_context(mode=context.PYNATIVE_MODE) + out2 = net(Tensor([9.0], ms.float32), Tensor( + [99.0], ms.float32), Tensor([1.0], ms.float32)) + allclose_nparray(out1.asnumpy(), out2.asnumpy(), 0.001, 0.001) + + +class Addn(Cell): + def __init__(self): + super().__init__() + self.parameter3 = Parameter(Tensor([1.0], ms.float32), + name="parameter3") + self.parameter4 = Parameter(Tensor([3.0], ms.float32), + name="parameter4") + self.addn = P.AddN() + + def construct(self, inputs): + out = self.addn((inputs, self.parameter3, self.parameter4)) + return out + + +class Relu(Cell): + def __init__(self): + super().__init__() + self.relu = P.ReLU() + + def construct(self, inputs): + out = self.relu(inputs) + return out + + +class SideEffectTwoAssignTwoAddnDependencyNet(Cell): + def __init__(self): + super().__init__() + self.parameter1 = Parameter(Tensor([1.0], ms.float32), + name="parameter1") + self.parameter2 = Parameter(Tensor([3.0], ms.float32), + name="parameter2") + self.assign = P.Assign() + self.addN = P.AddN() + + def construct(self, inputs): + self.assign(self.parameter1, inputs) + out = self.addN((inputs, self.parameter1, self.parameter2)) + self.assign(self.parameter2, inputs) + out = self.addN((out, self.parameter1, self.parameter2)) + return out + + def grad_mindspore_impl(self, params, grad_ys): + grad_net = GradOfAllInputsAndParams(self) + grad_net.set_train() + grad_out = grad_net(params, grad_ys) + return grad_out + + +# an infinite loop exists. 
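+# In the test below, the while conditions read parameters that change only
+# through side-effecting updates inside the loops; while-by-while scheduling
+# is not supported yet here (see the skip reason), so the conditions may never
+# observe those updates and the graph can spin indefinitely.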
+@pytest.mark.skip(reason="not supported yet") +def test_ctrl_while_by_while_and_if_in_first_while(): + class Net(Cell): + def __init__(self): + super().__init__() + self.relu = P.ReLU() + self.sigmoid = P.Sigmoid() + self.tanh = P.Tanh() + self.add = P.Add() + a = np.full((1,), 5, dtype=np.float32) + self.a = Parameter(Tensor(a), name="a") + b = np.full((1,), 4, dtype=np.float32) + self.b = Parameter(Tensor(b), name="b") + c = np.full((1,), 7, dtype=np.float32) + self.c = Parameter(Tensor(c), name="c") + + def construct(self, x): + out = x + while self.a < 7: + if self.a < self.c: + out = self.relu(x) + self.a += 1 + while self.c > 5: + out = self.add(out, out) + self.c -= 1 + return out + + context.set_context(mode=context.GRAPH_MODE) + input_np_a = np.random.randn(2, 3, 4, 5).astype(np.float32) + input_me_a = Tensor(input_np_a) + net = Net() + net(input_me_a) + + +# an infinite loop exists. +@pytest.mark.skip(reason="not supported yet") +def test_ctrl_while_by_while_and_while_in_first_while(): + class Net(Cell): + def __init__(self): + super().__init__() + self.relu = P.ReLU() + self.sigmoid = P.Sigmoid() + self.tanh = P.Tanh() + self.add = P.Add() + a = np.full((1,), 5, dtype=np.float32) + self.a = Parameter(Tensor(a), name="a") + b = np.full((1,), 4, dtype=np.float32) + self.b = Parameter(Tensor(b), name="b") + c = np.full((1,), 7, dtype=np.float32) + self.c = Parameter(Tensor(c), name="c") + + def construct(self, x): + out = x + while self.a < self.c: + out = self.relu(x) + while self.b > 1: + self.b -= 1 + self.a += 1 + + while self.c > 5: + out = self.add(out, out) + self.c -= 1 + return out + + context.set_context(mode=context.GRAPH_MODE) + input_np_a = np.random.randn(2, 3, 4, 5).astype(np.float32) + input_me_a = Tensor(input_np_a) + net = Net() + net(input_me_a) + + +class InplaceNet(Cell): + def __init__(self): + super().__init__() + self.bn1 = BatchNorm2d(num_features=4, eps=1e-4, + momentum=0.9, gamma_init=1, beta_init=0, + moving_mean_init=0, moving_var_init=1, data_format="NHWC") + self.bn2 = BatchNorm2d(num_features=4, eps=1e-4, + momentum=0.9, gamma_init=1, beta_init=0, + moving_mean_init=0, moving_var_init=1, data_format="NHWC") + self.add = P.Add() + self.relu = ReLU() + self.conv2d1 = Conv2d(in_channels=4, out_channels=4, + kernel_size=2, data_format="NHWC") + self.conv2d2 = Conv2d(in_channels=4, out_channels=4, + kernel_size=2, data_format="NHWC") + self.conv2d3 = Conv2d(in_channels=4, out_channels=4, + kernel_size=2, data_format="NHWC") + self.conv2d4 = Conv2d(in_channels=4, out_channels=4, + kernel_size=2, data_format="NHWC") + + def construct(self, input_x): + tmp_c1 = self.conv2d1(input_x) + tmp_c2 = self.conv2d2(input_x) + tmp_x = self.bn1(tmp_c1) + tmp_y = self.bn2(tmp_c2) + tmp_w = self.add(tmp_x, tmp_y) + tmp_w = self.relu(tmp_w) + + tmp_c1 = self.conv2d3(tmp_w) + tmp_c2 = self.conv2d4(tmp_w) + output = self.add(tmp_c1, tmp_c2) + return output + + +# @pytest.mark.level0 +# @pytest.mark.platform_x86_gpu_training +# @pytest.mark.env_onecard +@pytest.mark.skip(reason="not stable") +def test_ir_fusion_inplace_bn_conv_conv(): + clear_files() + context.set_context(mode=context.GRAPH_MODE, save_graphs=True) + input_np = np.random.uniform(0.0, 255.0, + size=[4, 4, 4, 4]).astype(np.float32) + label = np.ones([4, 4, 4, 4]).astype(np.float32) + net = InplaceNet() + loss = SoftmaxCrossEntropyWithLogits(sparse=False) + opt = Momentum(learning_rate=0.01, momentum=0.9, + params=filter(lambda x: x.requires_grad, net.get_parameters())) + net = amp.build_train_network(net, opt, 
loss, level="O2", + keep_batchnorm_fp32=False) + net.set_train() + net(Tensor(input_np), Tensor(label)) + find_accum = find_files("hwopt*cudnn_inplace*ir", + "inplace_algo: accumulation") + find_cover = find_files("hwopt*cudnn_inplace*ir", + "inplace_algo: cover") + assert find_accum == '1' + assert find_cover == '1' + + +def clean_all_ir_files(folder_path): + if os.path.exists(folder_path): + for file_name in os.listdir(folder_path): + if file_name.endswith('.ir') or file_name.endswith('.dot') or \ + file_name.endswith('.dat'): + os.remove(os.path.join(folder_path, file_name)) + + +def find_newest_validateir_file(folder_path): + ckpt_files = map(lambda f: os.path.join(folder_path, f), + filter(lambda f: re.match(r'\d+_validate_\d+.ir', f), + os.listdir(folder_path))) + return max(ckpt_files, key=os.path.getctime) + + +def read_file(): + filename = find_newest_validateir_file('./') + with open((os.path.join(filename)), 'r') as f: + content = f.read() + clean_all_ir_files('./') + return content + + +class Add(Cell): + def __init__(self): + super().__init__() + self.add = P.Add() + + def construct(self, x, y): + return self.add(x, y) + + +class MixControlNet(Cell): + def __init__(self, in_channel, x): + super().__init__() + #self._save_graphs(save_graph_flag=True, save_graph_path=".") + self.biasadd = P.BiasAdd() + self.equal = P.Equal() + self.addn = P.AddN() + self.conv = Conv2d(in_channels=in_channel, out_channels=in_channel, + kernel_size=1, stride=1, has_bias=False, + weight_init='ones', pad_mode='same') + self.bn = BatchNorm2d(num_features=in_channel) + self.controldepend = P.ControlDepend() + self.assignadd = P.AssignAdd() + self.assign = P.Assign() + self.relu = ReLU() + self.mean = P.ReduceMean(keep_dims=False) + self.bias = Parameter( + Tensor(np.random.randint(2, size=(3,)).astype((np.float32))), + name="bias") + self.bias2 = Parameter(Tensor(np.ones([3]).astype(np.float32)), + name="bias2") + self.parameterupdate = ParameterUpdate(self.bias) + self.value = Tensor(np.random.randn(*(3,)), ms.float32) + self.x = x + + def construct(self, input_x): + x = self.x + z = self.x + out = self.biasadd(input_x, self.bias) + while x < 20: + update = self.parameterupdate(self.bias2) + out = self.biasadd(out, update) + if x < 10: + out = self.addn((input_x, out)) + while z < 20: + out = self.conv(out) + z = z + 1 + if x < 20: + out = self.biasadd(out, self.bias) + if x % 2 == 0: + out = self.biasadd(out, self.bias) + assign = self.assignadd(self.bias, self.value) + self.controldepend(assign, out) + out = self.bn(out) + else: + out = self.conv(out) + x = x + 1 + out = self.addn((out, out)) + out = self.mean(out, (2, 3)) + return out + + +def use_build_train_network_controlflow_check_cast_num(network, level, input_x, + label, cast_num, + sparse=False, + loss_flag=True, + **kwargs): + opt = Momentum(learning_rate=0.0001, momentum=0.009, + params=network.trainable_params()) + loss = None + if loss_flag: + loss = SoftmaxCrossEntropyWithLogits(sparse=sparse, reduction='mean') + + train_network = ms.amp.build_train_network(network, opt, loss, level=level, + **kwargs) + out_me = train_network(input_x, label) + if context.get_context("mode") == 0: + content = read_file() + castnum = re.findall('Cast', content) + assert len(castnum) == cast_num + return out_me + + +# @pytest.mark.level0 +# @pytest.mark.platform_x86_gpu_training +# @pytest.mark.env_onecard +@pytest.mark.skip(reason="not stable") +def test_auto_mixed_precision_controlflow_auto_1(): + context.set_context(mode=context.PYNATIVE_MODE, 
save_graphs=True) + net = MixControlNet(3, 5) + input_x = Tensor( + np.random.randint(2, size=(1, 3, 2, 2)).astype((np.float32))) + label = Tensor(np.zeros([1, 3]).astype(np.float32)) + if ms.context.get_context("device_target") == "Ascend": + cast_num = 77 + if ms.context.get_context("device_target") == "GPU": + cast_num = 73 + use_build_train_network_controlflow_check_cast_num(net, "auto", input_x, + label, cast_num) + + +# @pytest.mark.level0 +# @pytest.mark.platform_x86_gpu_training +# @pytest.mark.env_onecard +@pytest.mark.skip(reason="not stable") +def test_updatestate_between_assigns(): + class UpdateState_Assigns(Cell): + def __init__(self): + super().__init__() + self.para1 = Parameter(Tensor(1, dtype=ms.int32), name='para1') + self.para2 = Parameter(Tensor(3, dtype=ms.int32), name='para2') + + def construct(self, value1, value2): + self.para1 = value1 + self.para2 = value2 + return self.para2 + + context.set_context(mode=context.GRAPH_MODE, save_graphs=True) + input_x = Tensor(10, dtype=ms.int32) + input_y = Tensor(30, dtype=ms.int32) + expect = Tensor(30, dtype=ms.int32) + net = UpdateState_Assigns() + out = net(input_x, input_y) + np.testing.assert_array_equal(out.asnumpy(), expect.asnumpy()) + if ms.context.get_context('mode') == 0: + content = read_file() + updatestate_num = re.findall('UpdateState', content) + assert len(updatestate_num) == 1 + + +# @pytest.mark.level0 +# @pytest.mark.platform_x86_gpu_training +# @pytest.mark.env_onecard +@pytest.mark.skip(reason="not stable") +def test_updatestate_between_maketuple_assign(): + class UpdateState_MakeTuple_Assign(Cell): + def __init__(self): + super().__init__() + self.para1 = Parameter(Tensor(1, dtype=ms.int32), name='para1') + self.para2 = Parameter(Tensor(3, dtype=ms.int32), name='para2') + self.para3 = Parameter(Tensor(5, dtype=ms.int32), name='para3') + + def construct(self, value1, value2, value3): + (self.para1, self.para2) = (value1, value2) + self.para3 = value3 + return self.para3 + + context.set_context(mode=context.GRAPH_MODE, save_graphs=True) + input_x = Tensor(10, dtype=ms.int32) + input_y = Tensor(30, dtype=ms.int32) + input_z = Tensor(50, dtype=ms.int32) + expect = Tensor(50, dtype=ms.int32) + net = UpdateState_MakeTuple_Assign() + out = net(input_x, input_y, input_z) + np.testing.assert_array_equal(out.asnumpy(), expect.asnumpy()) + if ms.context.get_context('mode') == 0: + content = read_file() + updatestate_num = re.findall('UpdateState', content) + assert len(updatestate_num) == 1 + + +# @pytest.mark.level0 +# @pytest.mark.platform_x86_gpu_training +# @pytest.mark.env_onecard +@pytest.mark.skip(reason="not stable") +def test_updatestate_between_assign_maketuple(): + class UpdateState_Assign_MakeTuple(Cell): + def __init__(self): + super().__init__() + self.para1 = Parameter(Tensor(1, dtype=ms.int32), name='para1') + self.para2 = Parameter(Tensor(3, dtype=ms.int32), name='para2') + self.para3 = Parameter(Tensor(5, dtype=ms.int32), name='para3') + + def construct(self, value1, value2, value3): + self.para1 = value1 + (self.para2, self.para3) = (value2, value3) + return self.para3 + + context.set_context(mode=context.GRAPH_MODE, save_graphs=True) + input_x = Tensor(10, dtype=ms.int32) + input_y = Tensor(30, dtype=ms.int32) + input_z = Tensor(50, dtype=ms.int32) + expect = Tensor(50, dtype=ms.int32) + net = UpdateState_Assign_MakeTuple() + out = net(input_x, input_y, input_z) + np.testing.assert_array_equal(out.asnumpy(), expect.asnumpy()) + if ms.context.get_context('mode') == 0: + content = read_file() + 
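+        # Consecutive parameter writes should be folded under a single
+        # UpdateState chain in the validate IR, so exactly one match is expected.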
updatestate_num = re.findall('UpdateState', content) + assert len(updatestate_num) == 1 diff --git a/tests/st/auto_monad/test_auto_monad_mindtester.py b/tests/st/auto_monad/test_auto_monad_mindtester.py new file mode 100644 index 0000000000..459266320f --- /dev/null +++ b/tests/st/auto_monad/test_auto_monad_mindtester.py @@ -0,0 +1,646 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import os +import pytest +import numpy as np +import mindspore as ms +import mindspore.ops.operations as P +from mindspore.nn import Cell +from mindspore import context, Tensor +from mindspore.common.parameter import Parameter +from mindspore.common.initializer import initializer +from mindspore.train.model import Model +from mindspore.ops.composite import GradOperation +from mindspore.common import ParameterTuple + + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + + +class _Grad(Cell): + def __init__(self, grad, network, wrt_params=False, real_inputs_count=None): + super().__init__() + self.network = network + self.grad = grad + self.sens_param = self.grad.sens_param + self.wrt_params = wrt_params + self.real_inputs_count = real_inputs_count + if self.wrt_params: + self.params = ParameterTuple(self.network.trainable_params()) + + def construct(self, *inputs): + if self.real_inputs_count is None or self.sens_param is False: + if self.wrt_params: + return self.grad(self.network, self.params)(*inputs) + return self.grad(self.network)(*inputs) + + real_inputs = inputs[:self.real_inputs_count] + sense_param_inputs = inputs[self.real_inputs_count:] + if self.wrt_params: + return self.grad(self.network, self.params)(*real_inputs, sense_param_inputs) + return self.grad(self.network)(*real_inputs, sense_param_inputs) + + +class GradOfFirstInput(_Grad): + """ + get grad of first input + """ + + def __init__(self, network, sens_param=True, real_inputs_count=None): + super().__init__(grad=GradOperation(sens_param=sens_param), + network=network, real_inputs_count=real_inputs_count) + + +class GradOfAllInputs(_Grad): + ''' + get grads of all inputs + ''' + + def __init__(self, network, sens_param=True, real_inputs_count=None): + super().__init__(grad=GradOperation(get_all=True, sens_param=sens_param), + network=network, real_inputs_count=real_inputs_count) + + +class GradOfAllInputsAndParams(_Grad): + ''' + get grads of all inputs and params + ''' + + def __init__(self, network, sens_param=True, real_inputs_count=None): + super().__init__(grad=GradOperation(get_all=True, get_by_list=True, sens_param=sens_param), + network=network, wrt_params=True, real_inputs_count=real_inputs_count) + + +def _count_unequal_element(data_expected, data_me, rtol, atol): + assert data_expected.shape == data_me.shape + total_count = len(data_expected.flatten()) + error = np.abs(data_expected - data_me) + greater = np.greater(error, atol + np.abs(data_me)*rtol) + loss_count = 
np.count_nonzero(greater) + assert (loss_count/total_count) < rtol, \ + "\ndata_expected_std:{0}\ndata_me_error:{1}\nloss:{2}".\ + format(data_expected[greater], data_me[greater], error[greater]) + + +def allclose_nparray(data_expected, data_me, rtol, atol, equal_nan=True): + if np.any(np.isnan(data_expected)): + assert np.allclose(data_expected, data_me, rtol, + atol, equal_nan=equal_nan) + elif not np.allclose(data_expected, data_me, rtol, atol, equal_nan=equal_nan): + _count_unequal_element(data_expected, data_me, rtol, atol) + else: + assert True + + +class ControlGraphSupportNotEqual(Cell): + def construct(self, x, y, z, input_data): + if x != y: + out = input_data + input_data + else: + out = input_data - input_data + if x == z: + out2 = input_data * input_data + else: + out2 = input_data / input_data + if x == z: + out3_f = (lambda a: a+a) + out3 = out3_f(input_data) + else: + out3_f = (lambda a: a+a+a) + out3 = out3_f(input_data) + return out, out2, out3 + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_ctrl_if_while_graph_support_not_equal_true(): + x = np.array(0).astype(np.float32) + y = np.array(3).astype(np.float32) + input_shape = (512, 512, 7, 7) + input_data = np.random.randn(*input_shape).astype(np.float32) + net = ControlGraphSupportNotEqual() + model = Model(net) + out_me = model.predict(Tensor(x), Tensor(y), Tensor(x), Tensor(input_data)) + out = input_data + input_data + out2 = input_data * input_data + out3 = input_data + input_data + allclose_nparray(out, out_me[0].asnumpy(), 0.0001, 0.0001) + allclose_nparray(out2, out_me[1].asnumpy(), 0.0001, 0.0001) + allclose_nparray(out3, out_me[2].asnumpy(), 0.0001, 0.0001) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_ctrl_if_while_graph_support_not_equal_false(): + x = np.array(0).astype(np.float32) + y = np.array(0).astype(np.float32) + z = np.array(3).astype(np.float32) + input_shape = (512, 512, 7, 7) + input_data = np.random.randn(*input_shape).astype(np.float32) + net = ControlGraphSupportNotEqual() + model = Model(net) + out_me = model.predict(Tensor(x), Tensor(y), Tensor(z), Tensor(input_data)) + out = input_data - input_data + out2 = input_data / input_data + out3 = input_data + input_data + input_data + allclose_nparray(out, out_me[0].asnumpy(), 0.0001, 0.0001) + allclose_nparray(out2, out_me[1].asnumpy(), 0.0001, 0.0001) + allclose_nparray(out3, out_me[2].asnumpy(), 0.0001, 0.0001) + + +class ControlBprop(Cell): + def construct(self, x, y, z, input_data): + if x != y: + out = input_data + input_data + else: + out = input_data - input_data + if x == z: + out2 = input_data * input_data + else: + out2 = input_data / input_data + if x == z: + out3_f = (lambda a: a+a) + out3 = out3_f(input_data) + else: + out3_f = (lambda a: a+a+a) + out3 = out3_f(input_data) + return out, out2, out3 + + def bprop(self, x, y, z, input_data, out, dout): + return x*2, y*3, z, input_data*5.1 + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_ctrl_if_while_bprop_true(): + x = np.array(0).astype(np.float32) + y = np.array(3).astype(np.float32) + input_shape = (512, 512, 7, 7) + input_data = np.random.randn(*input_shape).astype(np.float32) + net = ControlBprop() + grad_net = GradOfAllInputs(net, sens_param=False) + grad_net.set_train() + grads = 
grad_net(Tensor(x), Tensor(y), Tensor(x), Tensor(input_data)) + allclose_nparray(x*2, grads[0].asnumpy(), 0.0000, 0.0000) + allclose_nparray(y*3, grads[1].asnumpy(), 0.0000, 0.0000) + allclose_nparray(x, grads[2].asnumpy(), 0.0000, 0.0000) + allclose_nparray(input_data*5.1, grads[3].asnumpy(), 0.0000, 0.0000) + + +class TwoInput(Cell): + def __init__(self): + super().__init__() + self.op = P.Mul() + + def construct(self, x, y): + x = self.op(x, y) + return x + + +class InlineBpropTwoInput1(Cell): + def __init__(self): + super().__init__() + self.f = TwoInput() + self.f.set_grad() + self.grad = GradOfAllInputs(self.f, sens_param=False) + + def construct(self, x, y): + if x > y: + x = self.f(x, y) + else: + x = self.f(x, y) + return x + + def bprop(self, x, y, out, dout): + if x > y: + grads = self.grad(x, y) + else: + grads = self.grad(x, y) + return grads[0]*2, grads[1]*2 + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_ctrl_if_while_bprop_inlinebprop_twoinput(): + net = InlineBpropTwoInput1() + input1 = Tensor(np.array(2).astype(np.float32)) + input2 = Tensor(np.array(1).astype(np.float32)) + grad_net = GradOfAllInputs(net, sens_param=False) + grad_net.set_train() + grads = grad_net(input1, input2) + allclose_nparray(input1.asnumpy()*2, grads[1].asnumpy(), 0, 0) + allclose_nparray(input2.asnumpy()*2, grads[0].asnumpy(), 0, 0) + + +class ControlOneIfOneParaOneAddn(Cell): + def __init__(self, input_shape): + super().__init__() + self.addn = P.AddN() + self.assign = P.Assign() + self.inputdata = Parameter(initializer( + 1, input_shape, ms.float32), name="global_step") + + def construct(self, x, y, input_data): + if x > y: + out = self.inputdata + else: + out = self.addn([input_data, input_data, input_data]) + if x > y: + out = self.assign(self.inputdata, input_data) + return out + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_ctrl_if_para_addn_true(): + x = Tensor(1, ms.float32) + y = Tensor(0, ms.float32) + input_shape = (1024, 512, 7, 7) + input_data = np.random.randn(*input_shape).astype(np.float32) + net = ControlOneIfOneParaOneAddn(input_shape) + out = net(x, y, Tensor(input_data)) + allclose_nparray(input_data[0], out.asnumpy()[0], 0.0001, 0.0001) + + +class AddnCell(Cell): + def __init__(self): + super().__init__() + self.addn = P.AddN() + + def construct(self, x): + x = self.addn((x, x)) + return x + + +class SideEffectMemoryCellAddnNet(Cell): + def __init__(self): + super().__init__() + self.para = Parameter(Tensor([1.0], ms.float32), name="para") + self.assign = P.Assign() + self.addn = P.AddN() + self.addn1 = AddnCell() + + def construct(self, x): + x = self.addn1(x) + self.assign(self.para, x) + out = self.addn((self.para, x)) + return out + + def grad_mindspore_impl(self, params, grad_ys): + grad_net = GradOfAllInputsAndParams(self) + grad_net.set_train() + grad_out = grad_net(params, grad_ys) + return grad_out + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_grad_memory_addn(): + net = SideEffectMemoryCellAddnNet() + grad_ys = Tensor([18.0], ms.float32) + inputs = Tensor([9.0], ms.float32) + net.grad_mindspore_impl(inputs, grad_ys) + + +class SideEffectIOCellAddnNet(Cell): + def __init__(self): + super().__init__() + self.para1 = Parameter(Tensor([1.0], ms.float32), name="para1") + 
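+        # para1 above and para2 below are never used in arithmetic; they are
+        # only observed through the side-effecting Print op in construct.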
self.para2 = Parameter(Tensor([3.0], ms.float32), name="para2") + self.print = P.Print() + self.addn = AddnCell() + + def construct(self, x): + self.print("para1:", self.para1) + self.print("para2:", self.para2) + x = self.addn(x) + return x + + def grad_mindspore_impl(self, params, grad_ys): + grad_net = GradOfAllInputsAndParams(self) + grad_net.set_train() + grad_out = grad_net(params, grad_ys) + return grad_out + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_grad_io_addn(): + net = SideEffectIOCellAddnNet() + grad_ys = Tensor([18.0], ms.float32) + inputs = Tensor([9.0], ms.float32) + net.grad_mindspore_impl(inputs, grad_ys) + + +class SideEffectReturnParameterNet(Cell): + def __init__(self): + super().__init__() + self.para = Parameter(Tensor([1.0], ms.float32), name="para") + self.assign = P.Assign() + self.addn = P.AddN() + self.relu = P.ReLU() + + def construct(self, inputs): + p1 = self.assign(self.para, inputs) + out = self.addn((inputs, inputs, inputs)) + out = self.relu(out) + return p1 + + def grad_mindspore_impl(self, params, grad_ys): + grad_net = GradOfAllInputsAndParams(self) + grad_net.set_train() + grad_out = grad_net(params, grad_ys) + return grad_out + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_grad_read_dependency_return_parameter(): + net = SideEffectReturnParameterNet() + grad_ys = Tensor([18.0], ms.float32) + inputs = Tensor([9.0], ms.float32) + net.grad_mindspore_impl(inputs, grad_ys) + + +class SideEffectAssignAddnReluReturnParNet(Cell): + def __init__(self): + super().__init__() + self.parameter1 = Parameter( + Tensor([1.0], ms.float32), name="parameter1") + self.assign = P.Assign() + self.addN = P.AddN() + self.relu = P.ReLU() + + def construct(self, inputs): + p1 = self.assign(self.parameter1, inputs) + out = self.addN((inputs, inputs, inputs)) + out = self.relu(out) + return p1 + + def grad_mindspore_impl(self, params, grad_ys): + grad_net = GradOfAllInputsAndParams(self) + grad_net.set_train() + grad_out = grad_net(params, grad_ys) + return grad_out + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_side_effect_grad_read_dependency_assign_addn_relu_return_parameter(): + net = SideEffectAssignAddnReluReturnParNet() + grad_ys = Tensor([18.0], ms.float32) + inputs = Tensor([9.0], ms.float32) + out1 = net.grad_mindspore_impl(inputs, grad_ys) + net = SideEffectAssignAddnReluReturnParNet() + try: + context.set_context(mode=context.PYNATIVE_MODE) + out2 = net.grad_mindspore_impl(inputs, grad_ys) + allclose_nparray(out1[0][0].asnumpy(), out2[0] + [0].asnumpy(), 0.001, 0.001) + allclose_nparray(out1[1][0].asnumpy(), out2[1] + [0].asnumpy(), 0.001, 0.001) + finally: + context.set_context(mode=context.GRAPH_MODE) + + +class SideEffectPrintInHighOrdeAddnNet(Cell): + def __init__(self): + super().__init__() + self.parameter1 = Parameter( + Tensor([1.0], ms.float32), name="parameter1") + self.parameter2 = Parameter( + Tensor([3.0], ms.float32), name="parameter2") + self.assign = P.Assign() + self.addn = P.AddN() + self.mul = P.Mul() + self.print = P.Print() + + def construct(self, x): + self.high_order_func() + out = self.addn((self.parameter1, x, self.parameter2)) + return out + + def high_order_func(self): + self.print("parameter1: ", self.parameter1) + self.print("parameter2: ", 
self.parameter2) + return True + + def grad_mindspore_impl(self, params, grad_ys): + grad_net = GradOfAllInputsAndParams(self) + grad_net.set_train() + grad_out = grad_net(params, grad_ys) + return grad_out + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_side_effect_high_order_print_in_high_order_net(): + print_file = os.getcwd()+"/test_side_effect_high_order_print_in_high_order_net.data" + context.set_context(print_file_path=print_file) + net = SideEffectPrintInHighOrdeAddnNet() + out1 = net(Tensor([9.0], ms.float32)) + net = SideEffectPrintInHighOrdeAddnNet() + try: + context.set_context(mode=context.PYNATIVE_MODE) + out2 = net(Tensor([9.0], ms.float32)) + allclose_nparray(out1.asnumpy(), out2.asnumpy(), 0.001, 0.001) + finally: + context.set_context(mode=context.GRAPH_MODE) + + +class SideEffectControlFlowAssignDependTwoIfNet(Cell): + def __init__(self): + super().__init__() + self.parameter1 = Parameter( + Tensor([3.0], ms.float32), name="parameter1") + self.assign = P.Assign() + self.mul = P.Mul() + self.addn = P.AddN() + self.depend = P.Depend() + + def construct(self, x, y): + self.assign(self.parameter1, x) + if self.parameter1 > y: + x = self.mul(x, x) + p2 = self.assign(self.parameter1, x) + if self.parameter1 > y: + x = self.addn((x, self.parameter1)) + p3 = self.assign(self.parameter1, x) + self.depend(p3, p2) + return x + + def grad_mindspore_impl(self, params1, params2, grad_ys): + grad_net = GradOfAllInputsAndParams(self) + grad_net.set_train() + grad_out = grad_net(params1, params2, grad_ys) + return grad_out + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_side_effect_grad_control_flow_assign_depend_of_two_if(): + net = SideEffectControlFlowAssignDependTwoIfNet() + grad_ys = Tensor([18.0], ms.float32) + inputs1 = Tensor([9.0], ms.float32) + inputs2 = Tensor([6.0], ms.float32) + net.grad_mindspore_impl(inputs1, inputs2, grad_ys) + + +class SideEffectTwoAddnSwitchNet(Cell): + def __init__(self): + super().__init__() + self.addN = P.AddN() + + def construct(self, x): + y = x + x = self.addN((x, x, x)) + y = self.addN((y, y)) + if x > y: + return x + return y + + def grad_mindspore_impl(self, params, grad_ys): + grad_net = GradOfAllInputsAndParams(self) + grad_net.set_train() + grad_out = grad_net(params, grad_ys) + return grad_out + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_side_effect_grad_two_addn_switch(): + net = SideEffectTwoAddnSwitchNet() + grad_ys = Tensor([18.0], ms.float32) + inputs = Tensor([9.0], ms.float32) + out1 = net.grad_mindspore_impl(inputs, grad_ys) + net = SideEffectTwoAddnSwitchNet() + context.set_context(mode=context.PYNATIVE_MODE) + out2 = net.grad_mindspore_impl(inputs, grad_ys) + allclose_nparray(out1[0][0].asnumpy(), out2[0][0].asnumpy(), 0.001, 0.001) + + +class SideEffectGradIfNet(Cell): + def __init__(self): + super().__init__() + self.relu = P.ReLU() + a = np.full((1,), 5, dtype=np.float32) + self.a = Parameter(Tensor(a), name="a") + b = np.full((1,), 4, dtype=np.float32) + self.b = Parameter(Tensor(b), name="b") + + def construct(self, x): + if self.a > self.b: + x = self.relu(x) + out = x + else: + out = x + 2 + return out + + def grad_mindspore_impl(self, params, grad_ys): + grad_net = GradOfFirstInput(self) + grad_net.set_train() + grad_out = 
grad_net(params, grad_ys) + return grad_out + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_side_effect_grad_if(): + context.set_context(mode=context.GRAPH_MODE) + net = SideEffectGradIfNet() + grad_ys = Tensor([18.0], ms.float32) + inputs = Tensor([9.0], ms.float32) + out1 = net.grad_mindspore_impl(inputs, grad_ys) + net = SideEffectGradIfNet() + context.set_context(mode=context.PYNATIVE_MODE) + out2 = net.grad_mindspore_impl(inputs, grad_ys) + allclose_nparray(out1.asnumpy(), out2.asnumpy(), 0.001, 0.001) + + +class OneInputBprop(Cell): + def __init__(self): + super().__init__() + self.op = P.ReLU() + + def construct(self, x): + return self.op(x) + + def bprop(self, x, out, dout): + return (5 * x,) + + +class HighGrad(Cell): + def __init__(self, network, grad_list, sens_param=False, real_inputs_count=None): + super().__init__() + self.grads = [network] + for i in range(len(grad_list)-1): + _grad = grad_list[i](self.grads[i], sens_param=False) + self.grads.append(_grad) + self.final_grad = grad_list[-1](self.grads[-1], + sens_param=sens_param, real_inputs_count=real_inputs_count) + + def construct(self, *inputs): + return self.final_grad(*inputs) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_highgrad_one_input_sec_grad(): + net = OneInputBprop() + x = Tensor(np.array([2, 2]).astype(np.float32)) + grad_net = HighGrad(net, [GradOfFirstInput, GradOfFirstInput]) + dxdx = grad_net(x) + assert (dxdx.asnumpy() == np.array([5, 5]).astype(np.float32)).all() + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_highgrad_one_input_third_grad(): + net = OneInputBprop() + x = Tensor(np.array([2, 2]).astype(np.float32)) + grad_net = HighGrad( + net, [GradOfFirstInput, GradOfFirstInput, GradOfFirstInput]) + third_grad = grad_net(x) + assert (third_grad.asnumpy() == np.array([0, 0]).astype(np.float32)).all() diff --git a/tests/st/auto_monad/test_auto_monad_momentum_loss.py b/tests/st/auto_monad/test_auto_monad_momentum_loss.py new file mode 100644 index 0000000000..e86a8f2590 --- /dev/null +++ b/tests/st/auto_monad/test_auto_monad_momentum_loss.py @@ -0,0 +1,79 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +import pytest +import numpy as np +import mindspore.ops.operations as P +from mindspore.common.parameter import Parameter +from mindspore import context +from mindspore import Tensor +from mindspore.nn import Cell +from mindspore.nn.optim import Momentum +from mindspore.nn.wrap.cell_wrapper import WithLossCell +from mindspore.nn.wrap.loss_scale import TrainOneStepWithLossScaleCell +from mindspore.ops import functional as F +from mindspore.common import dtype as mstype + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + + +class Net(Cell): + def __init__(self, in_features, out_features): + super(Net, self).__init__() + self.weight = Parameter(Tensor(np.ones([out_features, in_features]).astype(np.float32)), name="weight") + self.bias = Parameter(Tensor(np.ones([out_features]).astype(np.float32)), name="bias") + self.matmul = P.MatMul() + self.add = P.Add() + + def construct(self, input_): + output = self.add(self.matmul(input_, self.weight), self.bias) + return output + + +def get_axis(x): + shape_op = P.Shape() + shape = shape_op(x) + length = F.tuple_len(shape) + perm = F.make_range(0, length) + return perm + + +class MSELoss(Cell): + def __init__(self): + super(MSELoss, self).__init__() + self.reduce_sum = P.ReduceSum() + self.square = P.Square() + self.reduce_mean = P.ReduceMean() + + def construct(self, data, label): + diff = data - label + return self.reduce_mean(self.square(diff), get_axis(diff)) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_momentum_loss(): + inputs = Tensor(np.ones([15, 1]).astype(np.float32)) + label = Tensor(np.zeros([15, 1]).astype(np.float32)) + net = Net(1, 1) + loss = MSELoss() + optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) + net_with_loss = WithLossCell(net, loss) + train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer, + scale_sense=Tensor(np.full((1), 1.0), dtype=mstype.float32)) + train_network.set_train() + output = train_network(inputs, label) + print("the result is ", output) diff --git a/tests/st/auto_monad/test_effect_ops.py b/tests/st/auto_monad/test_effect_ops.py new file mode 100644 index 0000000000..10c832df60 --- /dev/null +++ b/tests/st/auto_monad/test_effect_ops.py @@ -0,0 +1,397 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +import os +import tempfile +import pytest +import numpy as np +import mindspore.nn as nn +import mindspore.ops.operations as P +from mindspore import context, Tensor +from mindspore.common import dtype as mstype +from mindspore.common.parameter import Parameter +from mindspore.train.summary.summary_record import SummaryRecord +from tests.summary_utils import SummaryReader + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + + +class AssignAddNet(nn.Cell): + def __init__(self, para): + super(AssignAddNet, self).__init__() + self.para = Parameter(para, name="para") + self.assign_add = P.AssignAdd() + + def construct(self, value): + self.assign_add(self.para, value) + return self.para + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_assign_add(): + x = Tensor(1, dtype=mstype.int32) + y = Tensor(2, dtype=mstype.int32) + expect = Tensor(3, dtype=mstype.int32) + net = AssignAddNet(x) + out = net(y) + np.testing.assert_array_equal(out.asnumpy(), expect.asnumpy()) + + +class AssignSubNet(nn.Cell): + def __init__(self, para): + super(AssignSubNet, self).__init__() + self.para = Parameter(para, name="para") + self.assign_sub = P.AssignSub() + + def construct(self, value): + self.assign_sub(self.para, value) + return self.para + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_assign_sub(): + x = Tensor(3, dtype=mstype.int32) + y = Tensor(2, dtype=mstype.int32) + expect = Tensor(1, dtype=mstype.int32) + net = AssignSubNet(x) + out = net(y) + np.testing.assert_array_equal(out.asnumpy(), expect.asnumpy()) + + +class ScatterAddNet(nn.Cell): + def __init__(self, input_x): + super(ScatterAddNet, self).__init__() + self.input_x = Parameter(input_x, name="para") + self.scatter_add = P.ScatterAdd() + + def construct(self, indices, updates): + self.scatter_add(self.input_x, indices, updates) + return self.input_x + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_scatter_add(): + input_x = Tensor(np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]), mstype.float32) + indices = Tensor(np.array([[0, 1], [1, 1]]), mstype.int32) + updates = Tensor(np.ones([2, 2, 3]), mstype.float32) + expect = Tensor(np.array([[1.0, 1.0, 1.0], [3.0, 3.0, 3.0]]), mstype.float32) + net = ScatterAddNet(input_x) + out = net(indices, updates) + np.testing.assert_almost_equal(out.asnumpy(), expect.asnumpy()) + + +class ScatterSubNet(nn.Cell): + def __init__(self, input_x): + super(ScatterSubNet, self).__init__() + self.input_x = Parameter(input_x, name="para") + self.scatter_sub = P.ScatterSub() + + def construct(self, indices, updates): + self.scatter_sub(self.input_x, indices, updates) + return self.input_x + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_scatter_sub(): + input_x = Tensor(np.array([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]), mstype.float32) + indices = Tensor(np.array([[0, 1]]), mstype.int32) + updates = Tensor(np.array([[[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]]), mstype.float32) + expect = Tensor(np.array([[-1.0, -1.0, -1.0], [-1.0, -1.0, -1.0]]), mstype.float32) + net = ScatterSubNet(input_x) + out = net(indices, updates) + 
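+    # ScatterSub subtracts each update row from the indexed row in place:
+    # row 0: [0,0,0] - [1,1,1] and row 1: [1,1,1] - [2,2,2], both [-1,-1,-1].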
np.testing.assert_almost_equal(out.asnumpy(), expect.asnumpy()) + + +class ScatterMulNet(nn.Cell): + def __init__(self, input_x): + super(ScatterMulNet, self).__init__() + self.input_x = Parameter(input_x, name="para") + self.scatter_mul = P.ScatterMul() + + def construct(self, indices, updates): + self.scatter_mul(self.input_x, indices, updates) + return self.input_x + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_scatter_mul(): + input_x = Tensor(np.array([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]), mstype.float32) + indices = Tensor(np.array([[0, 1]]), mstype.int32) + updates = Tensor(np.array([[[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]]), mstype.float32) + expect = Tensor(np.array([[2.0, 2.0, 2.0], [4.0, 4.0, 4.0]]), mstype.float32) + net = ScatterMulNet(input_x) + out = net(indices, updates) + np.testing.assert_almost_equal(out.asnumpy(), expect.asnumpy()) + + +class ScatterDivNet(nn.Cell): + def __init__(self, input_x): + super(ScatterDivNet, self).__init__() + self.input_x = Parameter(input_x, name="para") + self.scatter_div = P.ScatterDiv() + + def construct(self, indices, updates): + self.scatter_div(self.input_x, indices, updates) + return self.input_x + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_scatter_div(): + input_x = Tensor(np.array([[6.0, 6.0, 6.0], [2.0, 2.0, 2.0]]), mstype.float32) + indices = Tensor(np.array([[0, 1]]), mstype.int32) + updates = Tensor(np.array([[[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]]), mstype.float32) + expect = Tensor(np.array([[3.0, 3.0, 3.0], [1.0, 1.0, 1.0]]), mstype.float32) + net = ScatterDivNet(input_x) + out = net(indices, updates) + np.testing.assert_almost_equal(out.asnumpy(), expect.asnumpy()) + + +class ScatterMaxNet(nn.Cell): + def __init__(self, input_x): + super(ScatterMaxNet, self).__init__() + self.input_x = Parameter(input_x, name="para") + self.scatter_max = P.ScatterMax() + + def construct(self, indices, updates): + self.scatter_max(self.input_x, indices, updates) + return self.input_x + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_scatter_max(): + input_x = Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), mstype.float32) + indices = Tensor(np.array([[0, 0], [1, 1]]), mstype.int32) + updates = Tensor(np.ones([2, 2, 3]) * 88, mstype.float32) + expect = Tensor(np.array([[88.0, 88.0, 88.0], [88.0, 88.0, 88.0]]), mstype.float32) + net = ScatterMaxNet(input_x) + out = net(indices, updates) + np.testing.assert_almost_equal(out.asnumpy(), expect.asnumpy()) + + +class ScatterMinNet(nn.Cell): + def __init__(self, input_x): + super(ScatterMinNet, self).__init__() + self.input_x = Parameter(input_x, name="para") + self.scatter_min = P.ScatterMin() + + def construct(self, indices, updates): + self.scatter_min(self.input_x, indices, updates) + return self.input_x + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_scatter_min(): + input_x = Tensor(np.array([[0.0, 1.0, 2.0], [0.0, 0.0, 0.0]]), mstype.float32) + indices = Tensor(np.array([[0, 0], [1, 1]]), mstype.int32) + updates = Tensor(np.ones([2, 2, 3]), mstype.float32) + expect = Tensor(np.array([[0.0, 1.0, 1.0], [0.0, 0.0, 0.0]]), mstype.float32) + net = ScatterMinNet(input_x) + out = net(indices, updates) + 
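+    # ScatterMin keeps the elementwise minimum of each indexed row and its
+    # update: row 0 -> min([0,1,2], 1) = [0,1,1]; row 1 stays [0,0,0].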
np.testing.assert_almost_equal(out.asnumpy(), expect.asnumpy()) + + +class ScatterUpdateNet(nn.Cell): + def __init__(self, input_x): + super(ScatterUpdateNet, self).__init__() + self.input_x = Parameter(input_x, name="para") + self.scatter_update = P.ScatterUpdate() + + def construct(self, indices, updates): + self.scatter_update(self.input_x, indices, updates) + return self.input_x + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_scatter_update(): + input_x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mstype.float32) + indices = Tensor(np.array([[0, 0], [1, 1]]), mstype.int32) + updates = Tensor(np.array([[[1.0, 2.2, 1.0], [2.0, 1.2, 1.0]], [[2.0, 2.2, 1.0], [3.0, 1.2, 1.0]]]), mstype.float32) + expect = Tensor(np.array([[2.0, 1.2, 1.0], [3.0, 1.2, 1.0]]), mstype.float32) + net = ScatterUpdateNet(input_x) + out = net(indices, updates) + np.testing.assert_almost_equal(out.asnumpy(), expect.asnumpy()) + + +class ScatterNdAddNet(nn.Cell): + def __init__(self, input_x): + super(ScatterNdAddNet, self).__init__() + self.input_x = Parameter(input_x, name="para") + self.scatter_nd_add = P.ScatterNdAdd() + + def construct(self, indices, updates): + self.scatter_nd_add(self.input_x, indices, updates) + return self.input_x + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_scatter_nd_add(): + input_x = Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8]), mstype.float32) + indices = Tensor(np.array([[2], [4], [1], [7]]), mstype.int32) + updates = Tensor(np.array([6, 7, 8, 9]), mstype.float32) + expect = Tensor(np.array([1, 10, 9, 4, 12, 6, 7, 17]), mstype.float32) + net = ScatterNdAddNet(input_x) + out = net(indices, updates) + np.testing.assert_almost_equal(out.asnumpy(), expect.asnumpy()) + + +class ScatterNdSubNet(nn.Cell): + def __init__(self, input_x): + super(ScatterNdSubNet, self).__init__() + self.input_x = Parameter(input_x, name="para") + self.scatter_nd_sub = P.ScatterNdSub() + + def construct(self, indices, updates): + self.scatter_nd_sub(self.input_x, indices, updates) + return self.input_x + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_scatter_nd_sub(): + input_x = Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8]), mstype.float32) + indices = Tensor(np.array([[2], [4], [1], [7]]), mstype.int32) + updates = Tensor(np.array([6, 7, 8, 9]), mstype.float32) + expect = Tensor(np.array([1, -6, -3, 4, -2, 6, 7, -1]), mstype.float32) + net = ScatterNdSubNet(input_x) + out = net(indices, updates) + np.testing.assert_almost_equal(out.asnumpy(), expect.asnumpy()) + + +class ScatterNdUpdateNet(nn.Cell): + def __init__(self, input_x): + super(ScatterNdUpdateNet, self).__init__() + self.input_x = Parameter(input_x, name="para") + self.scatter_nd_update = P.ScatterNdUpdate() + + def construct(self, indices, updates): + self.scatter_nd_update(self.input_x, indices, updates) + return self.input_x + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_scatter_nd_update(): + input_x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]), mstype.float32) + indices = Tensor(np.array([[0, 0], [1, 1]]), mstype.int32) + updates = Tensor(np.array([1.0, 2.2]), mstype.float32) + expect = Tensor(np.array([[1., 0.3, 3.6], [0.4, 2.2, -3.2]]), 
mstype.float32) + net = ScatterNdUpdateNet(input_x) + out = net(indices, updates) + np.testing.assert_almost_equal(out.asnumpy(), expect.asnumpy()) + + +class ScatterNonAliasingAddNet(nn.Cell): + def __init__(self, input_x): + super(ScatterNonAliasingAddNet, self).__init__() + self.input_x = Parameter(input_x, name="para") + self.scatter_non_aliasing_add = P.ScatterNonAliasingAdd() + + def construct(self, indices, updates): + out = self.scatter_non_aliasing_add(self.input_x, indices, updates) + return out + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_scatter_non_aliasing_add(): + input_x = Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8]), mstype.float32) + indices = Tensor(np.array([[2], [4], [1], [7]]), mstype.int32) + updates = Tensor(np.array([6, 7, 8, 9]), mstype.float32) + expect = Tensor(np.array([1.0, 10.0, 9.0, 4.0, 12.0, 6.0, 7.0, 17.0]), mstype.float32) + net = ScatterNonAliasingAddNet(input_x) + out = net(indices, updates) + np.testing.assert_almost_equal(out.asnumpy(), expect.asnumpy()) + + +class SummaryNet(nn.Cell): + def __init__(self): + super().__init__() + self.scalar_summary = P.ScalarSummary() + self.image_summary = P.ImageSummary() + self.tensor_summary = P.TensorSummary() + self.histogram_summary = P.HistogramSummary() + + def construct(self, image_tensor): + self.image_summary("image", image_tensor) + self.tensor_summary("tensor", image_tensor) + self.histogram_summary("histogram", image_tensor) + scalar = image_tensor[0][0][0][0] + self.scalar_summary("scalar", scalar) + return scalar + + +def train_summary_record(test_writer, steps): + """Train and record summary.""" + net = SummaryNet() + out_me_dict = {} + for i in range(0, steps): + image_tensor = Tensor(np.array([[[[i]]]]).astype(np.float32)) + out_put = net(image_tensor) + test_writer.record(i) + out_me_dict[i] = out_put.asnumpy() + return out_me_dict + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_summary(): + with tempfile.TemporaryDirectory() as tmp_dir: + steps = 2 + with SummaryRecord(tmp_dir) as test_writer: + train_summary_record(test_writer, steps=steps) + + file_name = os.path.realpath(test_writer.full_file_name) + with SummaryReader(file_name) as summary_writer: + for _ in range(steps): + event = summary_writer.read_event() + tags = set(value.tag for value in event.summary.value) + assert tags == {'tensor', 'histogram', 'scalar', 'image'} diff --git a/tests/st/auto_monad/test_effect_optimizer.py b/tests/st/auto_monad/test_effect_optimizer.py new file mode 100644 index 0000000000..a85fcff4dc --- /dev/null +++ b/tests/st/auto_monad/test_effect_optimizer.py @@ -0,0 +1,838 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +import pytest +import numpy as np +import mindspore.nn as nn +from mindspore import context, Tensor +from mindspore.ops import operations as P +from mindspore.common import dtype as mstype +from mindspore.common.parameter import Parameter + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + + +class AdamNet(nn.Cell): + def __init__(self, var, m, v): + super(AdamNet, self).__init__() + self.apply_adam = P.Adam() + self.var = Parameter(var, name="var") + self.m = Parameter(m, name="m") + self.v = Parameter(v, name="v") + + def construct(self, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad): + self.apply_adam(self.var, self.m, self.v, beta1_power, + beta2_power, lr, beta1, beta2, epsilon, grad) + return self.var, self.m, self.v + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_apply_adam(): + var = Tensor(np.ones([3, 3, 3]).astype(np.float32)) + m = Tensor(np.ones([3, 3, 3]).astype(np.float32)) + v = Tensor(np.ones([3, 3, 3]).astype(np.float32)) + net = AdamNet(var, m, v) + + beta1_power = Tensor(0.9, mstype.float32) + beta2_power = Tensor(0.999, mstype.float32) + lr = Tensor(0.001, mstype.float32) + beta1 = Tensor(0.9, mstype.float32) + beta2 = Tensor(0.999, mstype.float32) + epsilon = Tensor(1e-8, mstype.float32) + grad = Tensor(np.random.rand(3, 3, 3).astype(np.float32)) + new_var, new_m, new_v = net( + beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad) + assert ((new_var != var).any() and (new_m != m).any() and (new_v != v).any()), \ + "The results should be different!" + + +class ApplyAdaMaxNet(nn.Cell): + def __init__(self, val, m, v): + super(ApplyAdaMaxNet, self).__init__() + self.apply_ada_max = P.ApplyAdaMax() + self.var = Parameter(val, name="var") + self.m = Parameter(m, name="m") + self.v = Parameter(v, name="v") + + def construct(self, beta1_power, lr, beta1, beta2, epsilon, grad): + self.apply_ada_max(self.var, self.m, self.v, + beta1_power, lr, beta1, beta2, epsilon, grad) + return self.var, self.m, self.v + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_apply_ada_max(): + var = Tensor(np.random.rand(3, 3).astype(np.float32)) + m = Tensor(np.random.rand(3, 3).astype(np.float32)) + v = Tensor(np.random.rand(3, 3).astype(np.float32)) + net = ApplyAdaMaxNet(var, m, v) + + beta1_power = Tensor(0.9, mstype.float32) + lr = Tensor(0.001, mstype.float32) + beta1 = Tensor(0.9, mstype.float32) + beta2 = Tensor(0.99, mstype.float32) + epsilon = Tensor(1e-10, mstype.float32) + grad = Tensor(np.random.rand(3, 3).astype(np.float32)) + new_var, new_m, new_v = net(beta1_power, lr, beta1, beta2, epsilon, grad) + assert ((new_var != var).any() and (new_m != m).any() and (new_v != v).any()), \ + "The results should be different!" 
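+
+
+# For orientation: the textbook Adadelta update that ApplyAdadelta below is
+# expected to perform (a reference sketch; the exact kernel may differ):
+#   accum        = rho * accum + (1 - rho) * grad^2
+#   update       = sqrt(accum_update + eps) / sqrt(accum + eps) * grad
+#   accum_update = rho * accum_update + (1 - rho) * update^2
+#   var          = var - lr * update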
+ + +class ApplyAdadeltaNet(nn.Cell): + def __init__(self, var, accum, accum_update): + super(ApplyAdadeltaNet, self).__init__() + self.apply_adadelta = P.ApplyAdadelta() + self.var = Parameter(var, name="var") + self.accum = Parameter(accum, name="accum") + self.accum_update = Parameter(accum_update, name="accum_update") + + def construct(self, lr, rho, epsilon, grad): + self.apply_adadelta(self.var, self.accum, + self.accum_update, lr, rho, epsilon, grad) + return self.var, self.accum, self.accum_update + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_apply_adadelta(): + var = Tensor(np.random.rand(3, 3).astype(np.float32)) + accum = Tensor(np.random.rand(3, 3).astype(np.float32)) + accum_update = Tensor(np.random.rand(3, 3).astype(np.float32)) + net = ApplyAdadeltaNet(var, accum, accum_update) + + lr = Tensor(0.001, mstype.float32) + rho = Tensor(0.0, mstype.float32) + epsilon = Tensor(1e-6, mstype.float32) + grad = Tensor(np.random.rand(3, 3).astype(np.float32)) + new_var, new_accum, new_accum_update = net(lr, rho, epsilon, grad) + assert ((new_var != var).any() and (new_accum != accum).any() and (new_accum_update != accum_update).any()), \ + "The results should be different!" + + +class ApplyAdagrad(nn.Cell): + def __init__(self, var, accum): + super(ApplyAdagrad, self).__init__() + self.apply_adagrad = P.ApplyAdagrad() + self.var = Parameter(var, name="var") + self.accum = Parameter(accum, name="accum") + + def construct(self, lr, grad): + self.apply_adagrad(self.var, self.accum, lr, grad) + return self.var, self.accum + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_apply_adagrad(): + var = Tensor(np.random.rand(3, 3).astype(np.float32)) + accum = Tensor(np.random.rand(3, 3).astype(np.float32)) + net = ApplyAdagrad(var, accum) + + lr = Tensor(0.001, mstype.float32) + grad = Tensor(np.random.rand(3, 3).astype(np.float32)) + new_var, new_accum = net(lr, grad) + assert ((new_var != var).any() and (new_accum != accum).any()), \ + "The results should be different!" + + +class ApplyAdagradV2Net(nn.Cell): + def __init__(self, var, accum): + super(ApplyAdagradV2Net, self).__init__() + self.apply_adagrad_v2 = P.ApplyAdagradV2(epsilon=1e-6) + self.var = Parameter(var, name="var") + self.accum = Parameter(accum, name="accum") + + def construct(self, lr, grad): + self.apply_adagrad_v2(self.var, self.accum, lr, grad) + return self.var, self.accum + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_apply_adagrad_v2(): + var = Tensor(np.random.rand(3, 3).astype(np.float32)) + accum = Tensor(np.random.rand(3, 3).astype(np.float32)) + net = ApplyAdagradV2Net(var, accum) + + lr = Tensor(0.001, mstype.float32) + grad = Tensor(np.random.rand(3, 3).astype(np.float32)) + new_var, new_accum = net(lr, grad) + assert ((new_var != var).any() and (new_accum != accum).any()), \ + "The results should be different!" 
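+
+
+# AddSign (Bello et al., 2017) reference rule for the test below (a sketch,
+# not necessarily the exact kernel semantics):
+#   m   = beta * m + (1 - beta) * grad
+#   var = var - lr * (alpha + sign_decay * sign(grad) * sign(m)) * grad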
+ + +class ApplyAddSignNet(nn.Cell): + def __init__(self, var, m): + super(ApplyAddSignNet, self).__init__() + self.apply_add_sign = P.ApplyAddSign() + self.var = Parameter(var, name="var") + self.m = Parameter(m, name="m") + + def construct(self, lr, alpha, sign_decay, beta, grad): + self.apply_add_sign(self.var, self.m, lr, alpha, + sign_decay, beta, grad) + return self.var, self.m + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_apply_add_sign(): + var = Tensor(np.random.rand(3, 3).astype(np.float32)) + m = Tensor(np.random.rand(3, 3).astype(np.float32)) + net = ApplyAddSignNet(var, m) + + lr = Tensor(0.001, mstype.float32) + alpha = Tensor(1.0, mstype.float32) + sign_decay = Tensor(0.99, mstype.float32) + beta = Tensor(0.9, mstype.float32) + grad = Tensor(np.random.rand(3, 3).astype(np.float32)) + new_var, new_m = net(lr, alpha, sign_decay, beta, grad) + assert ((new_var != var).any() and (new_m != m).any()), \ + "The results should be different!" + + +class ApplyCenteredRMSPropNet(nn.Cell): + def __init__(self, var): + super(ApplyCenteredRMSPropNet, self).__init__() + self.apply_centered_rms_prop = P.ApplyCenteredRMSProp() + self.var = Parameter(var, name="var") + + def construct(self, mean_grad, mean_square, moment, grad, learning_rate): + self.apply_centered_rms_prop(self.var, mean_grad, mean_square, moment, grad, + learning_rate, 0.0, 1e-10, 0.05) + return self.var + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_apply_centered_rms_prop(): + var = Tensor( + np.arange(-6, 6).astype(np.float32).reshape(2, 3, 2), mstype.float32) + net = ApplyCenteredRMSPropNet(var) + + mean_grad = Tensor(np.arange(12).astype( + np.float32).reshape(2, 3, 2), mstype.float32) + mean_square = Tensor( + np.arange(-8, 4).astype(np.float32).reshape(2, 3, 2), mstype.float32) + moment = Tensor(np.arange(12).astype( + np.float32).reshape(2, 3, 2), mstype.float32) + grad = Tensor(np.arange(12).astype( + np.float32).reshape(2, 3, 2), mstype.float32) + learning_rate = Tensor(0.9, mstype.float32) + new_var = net(mean_grad, mean_square, moment, grad, learning_rate) + assert (new_var != var).any(), "The results should be different!" 
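+
+
+# FTRL-Proximal reference for the test below (assumed standard semantics):
+#   accum_new = accum + grad^2
+#   linear   += grad - (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var
+#   var       = (sign(linear) * l1 - linear) / (accum_new^(-lr_power) / lr + 2 * l2)
+#               if |linear| > l1, else 0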
+ + +class ApplyFtrlNet(nn.Cell): + def __init__(self, var, accum, linear): + super(ApplyFtrlNet, self).__init__() + self.apply_ftrl = P.ApplyFtrl() + self.var = Parameter(var, name="var") + self.accum = Parameter(accum, name="accum") + self.linear = Parameter(linear, name="linear") + + def construct(self, grad, lr, l1, l2, lr_power): + self.apply_ftrl(self.var, self.accum, self.linear, + grad, lr, l1, l2, lr_power) + return self.var, self.accum, self.linear + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_apply_ftrl(): + var = Tensor(np.random.rand(3, 3).astype(np.float32)) + accum = Tensor(np.random.rand(3, 3).astype(np.float32)) + linear = Tensor(np.random.rand(3, 3).astype(np.float32)) + net = ApplyFtrlNet(var, accum, linear) + + grad = Tensor(np.random.randint(-4, 4, (3, 3)), mstype.float32) + lr = Tensor(0.001, mstype.float32) + l1 = Tensor(0.0, mstype.float32) + l2 = Tensor(0.0, mstype.float32) + lr_power = Tensor(-0.5, mstype.float32) + new_var, new_accum, new_linear = net(grad, lr, l1, l2, lr_power) + assert ((new_var != var).any() and (new_accum != accum).any() and (new_linear != linear).any()), \ + "The results should be different!" + + +class ApplyGradientDescentNet(nn.Cell): + def __init__(self, var): + super(ApplyGradientDescentNet, self).__init__() + self.apply_gradient_descent = P.ApplyGradientDescent() + self.var = Parameter(var, name="var") + + def construct(self, alpha, delta): + self.apply_gradient_descent(self.var, alpha, delta) + return self.var + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_apply_gradient_descent(): + var = Tensor(np.random.rand(3, 3).astype(np.float32)) + net = ApplyGradientDescentNet(var) + + alpha = Tensor(0.001, mstype.float32) + delta = Tensor(np.random.rand(3, 3).astype(np.float32)) + new_var = net(alpha, delta) + assert (new_var != var).any(), "The results should be different!" + + +class ApplyMomentumNet(nn.Cell): + def __init__(self, var, accum): + super(ApplyMomentumNet, self).__init__() + self.apply_momentum = P.ApplyMomentum(gradient_scale=1024.0) + self.var = Parameter(var, name='var') + self.accum = Parameter(accum, name='accum') + + def construct(self, lr, grad, momentum): + self.apply_momentum(self.var, self.accum, lr, grad, momentum) + return self.var, self.accum + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_apply_momentum(): + var = Tensor(np.random.normal(size=(2, 3, 3, 4)).astype(np.float32)) + accum = Tensor(np.random.normal(size=(2, 3, 3, 4)).astype(np.float32)) + net = ApplyMomentumNet(var, accum) + + lr = Tensor(np.random.normal(size=(1,)).astype(np.float32)) + grad = Tensor(np.random.normal(size=(2, 3, 3, 4)).astype(np.float32)) + momentum = Tensor(np.random.normal(size=(1,)).astype(np.float32)) + new_var, new_accum = net(lr, grad, momentum) + assert ((new_var != var).any() and (new_accum != accum).any()), \ + "The results should be different!" 
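+
+
+# PowerSign, the multiplicative companion of AddSign (reference sketch):
+#   m   = beta * m + (1 - beta) * grad
+#   var = var - lr * exp(logbase * sign_decay * sign(grad) * sign(m)) * grad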
+ + +class ApplyPowerSignNet(nn.Cell): + def __init__(self, var, m): + super(ApplyPowerSignNet, self).__init__() + self.apply_power_sign = P.ApplyPowerSign() + self.var = Parameter(var, name="var") + self.m = Parameter(m, name="m") + + def construct(self, lr, logbase, sign_decay, beta, grad): + self.apply_power_sign(self.var, self.m, lr, + logbase, sign_decay, beta, grad) + return self.var, self.m + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_apply_power_sign(): + var = Tensor(np.random.rand(3, 3).astype(np.float32)) + m = Tensor(np.random.rand(3, 3).astype(np.float32)) + net = ApplyPowerSignNet(var, m) + + lr = Tensor(0.001, mstype.float32) + logbase = Tensor(np.e, mstype.float32) + sign_decay = Tensor(0.99, mstype.float32) + beta = Tensor(0.9, mstype.float32) + grad = Tensor(np.random.rand(3, 3).astype(np.float32)) + new_var, new_m = net(lr, logbase, sign_decay, beta, grad) + assert ((new_var != var).any() and (new_m != m).any()), \ + "The results should be different!" + + +class ApplyProximalAdagradNet(nn.Cell): + def __init__(self, var, accum): + super(ApplyProximalAdagradNet, self).__init__() + self.apply_proximal_adagrad = P.ApplyProximalAdagrad() + self.var = Parameter(var, name="var") + self.accum = Parameter(accum, name='accum') + + def construct(self, lr, l1, l2, grad): + self.apply_proximal_adagrad(self.var, self.accum, lr, l1, l2, grad) + return self.var, self.accum + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_apply_proximal_adagrad(): + var = Tensor(np.random.rand(3, 3).astype(np.float32)) + accum = Tensor(np.random.rand(3, 3).astype(np.float32)) + net = ApplyProximalAdagradNet(var, accum) + + lr = Tensor(0.01, mstype.float32) + l1 = Tensor(0.0, mstype.float32) + l2 = Tensor(0.0, mstype.float32) + grad = Tensor(np.random.rand(3, 3).astype(np.float32)) + new_var, new_accum = net(lr, l1, l2, grad) + assert ((new_var != var).any() and (new_accum != accum).any()), \ + "The results should be different!" + + +class ApplyProximalGradientDescentNet(nn.Cell): + def __init__(self, var): + super(ApplyProximalGradientDescentNet, self).__init__() + self.apply_proximal_gradient_descent = P.ApplyProximalGradientDescent() + self.var = Parameter(var, name="var") + + def construct(self, alpha, l1, l2, delta): + self.apply_proximal_gradient_descent(self.var, alpha, l1, l2, delta) + return self.var + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_apply_proximal_gradient_descent(): + var = Tensor(np.random.rand(3, 3).astype(np.float32)) + net = ApplyProximalGradientDescentNet(var) + + alpha = Tensor(0.001, mstype.float32) + l1 = Tensor(0.0, mstype.float32) + l2 = Tensor(0.0, mstype.float32) + delta = Tensor(np.random.rand(3, 3).astype(np.float32)) + new_var = net(alpha, l1, l2, delta) + assert (new_var != var).any(), "The results should be different!" 
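+
+
+# Generic RMSProp-with-momentum rule for reference; the trailing positional
+# constants in ApplyRMSPropNet.construct below are the decay, momentum and
+# epsilon hyper-parameters (per the op's documented input order):
+#   mean_square = decay * mean_square + (1 - decay) * grad^2
+#   moment      = momentum * moment + lr * grad / sqrt(mean_square + eps)
+#   var         = var - moment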
+ + +class ApplyRMSPropNet(nn.Cell): + def __init__(self, var): + super(ApplyRMSPropNet, self).__init__() + self.apply_rms_prop = P.ApplyRMSProp() + self.var = Parameter(var, name="var") + + def construct(self, mean_square, moment, learning_rate, grad): + self.apply_rms_prop(self.var, mean_square, moment, + learning_rate, grad, 0.0, 1e-10, 0.001) + return self.var + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_apply_rms_prop(): + var = Tensor(1., mstype.float32) + net = ApplyRMSPropNet(var) + + mean_square = Tensor(2., mstype.float32) + moment = Tensor(1., mstype.float32) + learning_rate = Tensor(0.9, mstype.float32) + grad = Tensor(2., mstype.float32) + new_var = net(mean_square, moment, learning_rate, grad) + assert (new_var != var).any(), "The results should be different!" + + +class FusedSparseAdamNet(nn.Cell): + def __init__(self, var, m, v): + super(FusedSparseAdamNet, self).__init__() + self.fused_sparse_adam = P.FusedSparseAdam() + self.var = Parameter(var, name="var") + self.m = Parameter(m, name="m") + self.v = Parameter(v, name="v") + + def construct(self, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, indices): + self.fused_sparse_adam(self.var, self.m, self.v, beta1_power, beta2_power, lr, beta1, beta2, + epsilon, grad, indices) + return self.var, self.m, self.v + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_fused_sparse_adam(): + var = Tensor(np.ones([3, 1, 2]).astype(np.float32)) + m = Tensor(np.ones([3, 1, 2]).astype(np.float32)) + v = Tensor(np.ones([3, 1, 2]).astype(np.float32)) + net = FusedSparseAdamNet(var, m, v) + + beta1_power = Tensor(0.9, mstype.float32) + beta2_power = Tensor(0.999, mstype.float32) + lr = Tensor(0.001, mstype.float32) + beta1 = Tensor(0.9, mstype.float32) + beta2 = Tensor(0.999, mstype.float32) + epsilon = Tensor(1e-8, mstype.float32) + gradient = Tensor(np.random.rand(2, 1, 2), mstype.float32) + indices = Tensor([0, 1], mstype.int32) + new_var, new_m, new_v = net( + beta1_power, beta2_power, lr, beta1, beta2, epsilon, gradient, indices) + assert ((new_var != var).any() and (new_m != m).any() and (new_v != v).any()), \ + "The results should be different!" 
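+
+
+# The FusedSparse* ops below consume a row-sparse gradient: `grad` carries the
+# gradient rows and `indices` names the parameter rows those values update.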
+ + +class FusedSparseFtrlNet(nn.Cell): + def __init__(self, var, accum, linear): + super(FusedSparseFtrlNet, self).__init__() + self.fused_sparse_ftrl = P.FusedSparseFtrl( + lr=0.01, l1=0.0, l2=0.0, lr_power=-0.5) + self.var = Parameter(var, name="var") + self.accum = Parameter(accum, name="accum") + self.linear = Parameter(linear, name="linear") + + def construct(self, grad, indices): + self.fused_sparse_ftrl(self.var, self.accum, + self.linear, grad, indices) + return self.var, self.accum, self.linear + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_fused_sparse_ftrl(): + var = Tensor(np.random.rand(3, 1, 2).astype(np.float32)) + accum = Tensor(np.random.rand(3, 1, 2).astype(np.float32)) + linear = Tensor(np.random.rand(3, 1, 2).astype(np.float32)) + net = FusedSparseFtrlNet(var, accum, linear) + + grad = Tensor(np.random.rand(2, 1, 2).astype(np.float32)) + indices = Tensor(np.array([0, 1]).astype(np.int32)) + new_var, new_accum, new_linear = net(grad, indices) + assert ((new_var != var).any() and (new_accum != accum).any() and (new_linear != linear).any()), \ + "The results should be different!" + + +class FusedSparseLazyAdamNet(nn.Cell): + def __init__(self, var, m, v): + super(FusedSparseLazyAdamNet, self).__init__() + self.fused_sparse_lazyadam = P.FusedSparseLazyAdam() + self.var = Parameter(var, name="var") + self.m = Parameter(m, name="m") + self.v = Parameter(v, name="v") + + def construct(self, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, indices): + self.fused_sparse_lazyadam(self.var, self.m, self.v, beta1_power, beta2_power, lr, beta1, + beta2, epsilon, grad, indices) + return self.var, self.m, self.v + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_fused_sparse_lazyadam(): + var = Tensor(np.ones([3, 1, 2]).astype(np.float32)) + m = Tensor(np.ones([3, 1, 2]).astype(np.float32)) + v = Tensor(np.ones([3, 1, 2]).astype(np.float32)) + net = FusedSparseLazyAdamNet(var, m, v) + + beta1_power = Tensor(0.9, mstype.float32) + beta2_power = Tensor(0.999, mstype.float32) + lr = Tensor(0.001, mstype.float32) + beta1 = Tensor(0.9, mstype.float32) + beta2 = Tensor(0.999, mstype.float32) + epsilon = Tensor(1e-8, mstype.float32) + gradient = Tensor(np.random.rand(2, 1, 2), mstype.float32) + indices = Tensor([0, 1], mstype.int32) + new_var, new_m, new_v = net( + beta1_power, beta2_power, lr, beta1, beta2, epsilon, gradient, indices) + assert ((new_var != var).any() and (new_m != m).any() and (new_v != v).any()), \ + "The results should be different!" 
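+
+
+# Proximal Adagrad reference (with l1 = l2 = 0, as in the test below, this
+# reduces to plain Adagrad on the indexed rows):
+#   accum += grad^2
+#   prox_v = var - lr * grad / sqrt(accum)
+#   var    = sign(prox_v) * max(|prox_v| - lr * l1, 0) / (1 + lr * l2 / sqrt(accum))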
+
+
+class FusedSparseProximalAdagradNet(nn.Cell):
+    def __init__(self, var, accum):
+        super(FusedSparseProximalAdagradNet, self).__init__()
+        self.fused_sparse_proximal_adagrad = P.FusedSparseProximalAdagrad()
+        self.var = Parameter(var, name="var")
+        self.accum = Parameter(accum, name="accum")
+
+    def construct(self, lr, l1, l2, grad, indices):
+        self.fused_sparse_proximal_adagrad(
+            self.var, self.accum, lr, l1, l2, grad, indices)
+        return self.var, self.accum
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_fused_sparse_proximal_adagrad():
+    var = Tensor(np.random.rand(3, 1, 2).astype(np.float32))
+    accum = Tensor(np.random.rand(3, 1, 2).astype(np.float32))
+    net = FusedSparseProximalAdagradNet(var, accum)
+
+    lr = Tensor(0.01, mstype.float32)
+    l1 = Tensor(0.0, mstype.float32)
+    l2 = Tensor(0.0, mstype.float32)
+    grad = Tensor(np.random.rand(2, 1, 2).astype(np.float32))
+    indices = Tensor(np.array([0, 1]).astype(np.int32))
+    new_var, new_accum = net(lr, l1, l2, grad, indices)
+    assert ((new_var != var).any() and (new_accum != accum).any()), \
+        "The results should be different!"
+
+
+class SparseApplyAdagradNet(nn.Cell):
+    def __init__(self, var, accum):
+        super(SparseApplyAdagradNet, self).__init__()
+        self.sparse_apply_adagrad = P.SparseApplyAdagrad(lr=0.01)
+        self.var = Parameter(var, name="var")
+        self.accum = Parameter(accum, name="accum")
+
+    def construct(self, grad, indices):
+        self.sparse_apply_adagrad(self.var, self.accum, grad, indices)
+        return self.var, self.accum
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_sparse_apply_adagrad():
+    var = Tensor(np.random.rand(3, 3).astype(np.float32))
+    accum = Tensor(np.random.rand(3, 3).astype(np.float32))
+    net = SparseApplyAdagradNet(var, accum)
+
+    grad = Tensor(np.random.rand(3, 3).astype(np.float32))
+    indices = Tensor(np.ones((3,), np.int32))
+    new_var, _ = net(grad, indices)
+    # new_accum is equal to accum.
+    assert (new_var != var).any(), "The results should be different!"
+
+
+class SparseApplyAdagradV2Net(nn.Cell):
+    def __init__(self, var, accum):
+        super(SparseApplyAdagradV2Net, self).__init__()
+        self.sparse_apply_adagrad_v2 = P.SparseApplyAdagradV2(
+            lr=0.01, epsilon=0.001)
+        self.var = Parameter(var, name="var")
+        self.accum = Parameter(accum, name="accum")
+
+    def construct(self, grad, indices):
+        self.sparse_apply_adagrad_v2(self.var, self.accum, grad, indices)
+        return self.var, self.accum
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_sparse_apply_adagrad_v2():
+    var = Tensor(np.random.rand(3, 3).astype(np.float32))
+    accum = Tensor(np.random.rand(3, 3).astype(np.float32))
+    net = SparseApplyAdagradV2Net(var, accum)
+
+    grad = Tensor(np.random.rand(3, 3).astype(np.float32))
+    indices = Tensor(np.ones((3,), np.int32))
+    new_var, new_accum = net(grad, indices)
+    assert ((new_var != var).any() and (new_accum != accum).any()), \
+        "The results should be different!"
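Note: test_sparse_apply_adagrad checks only `var` because, per the comment above, that operator variant returns `accum` unchanged. A hypothetical NumPy reference showing both behaviours follows; the `update_slots` flag is an assumption standing in for the variant difference, not the operators' actual API.

    import numpy as np

    def sparse_adagrad_rows(var, accum, lr, grad, indices, update_slots=True):
        # Per-row Adagrad step; `update_slots` models whether the variant
        # writes the squared gradient back into `accum`.
        for g, i in zip(grad, indices):
            if update_slots:
                accum[i] += g * g
            var[i] -= lr * g / np.sqrt(accum[i])
        return var, accum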
+ + +class SparseApplyFtrlNet(nn.Cell): + def __init__(self, var, accum, linear): + super(SparseApplyFtrlNet, self).__init__() + self.sparse_apply_ftrl = P.SparseApplyFtrl( + lr=0.01, l1=0.0, l2=0.0, lr_power=-0.5) + self.var = Parameter(var, name="var") + self.accum = Parameter(accum, name="accum") + self.linear = Parameter(linear, name="linear") + + def construct(self, grad, indices): + self.sparse_apply_ftrl(self.var, self.accum, + self.linear, grad, indices) + return self.var, self.accum, self.linear + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_sparse_apply_ftrl(): + var = Tensor(np.random.rand(3, 3).astype(np.float32)) + accum = Tensor(np.random.rand(3, 3).astype(np.float32)) + linear = Tensor(np.random.rand(3, 3).astype(np.float32)) + net = SparseApplyFtrlNet(var, accum, linear) + + grad = Tensor(np.random.rand(3, 3).astype(np.float32)) + indices = Tensor(np.ones((3,), np.int32)) + new_var, new_accum, new_linear = net(grad, indices) + assert ((new_var != var).any() and (new_accum != accum).any() and (new_linear != linear).any()), \ + "The results should be different!" + + +class SparseApplyFtrlV2Net(nn.Cell): + def __init__(self, var, accum, linear): + super(SparseApplyFtrlV2Net, self).__init__() + self.sparse_apply_ftrl_v2 = P.SparseApplyFtrlV2( + lr=0.01, l1=0.0, l2=0.0, l2_shrinkage=0.0, lr_power=-0.5) + self.var = Parameter(var, name="var") + self.accum = Parameter(accum, name="accum") + self.linear = Parameter(linear, name="linear") + + def construct(self, grad, indices): + self.sparse_apply_ftrl_v2( + self.var, self.accum, self.linear, grad, indices) + return self.var, self.accum, self.linear + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_sparse_apply_ftrl_v2(): + var = Tensor(np.random.rand(3, 3).astype(np.float32)) + accum = Tensor(np.random.rand(3, 3).astype(np.float32)) + linear = Tensor(np.random.rand(3, 3).astype(np.float32)) + net = SparseApplyFtrlV2Net(var, accum, linear) + + grad = Tensor(np.random.rand(3, 3).astype(np.float32)) + indices = Tensor(np.ones((3,), np.int32)) + new_var, new_accum, new_linear = net(grad, indices) + assert ((new_var != var).any() and (new_accum != accum).any() and (new_linear != linear).any()), \ + "The results should be different!" 
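Note: the FTRL tests exercise what appears to be the standard FTRL-proximal scheme, judging by the lr/l1/l2/lr_power attributes. A TensorFlow-style per-row reference is sketched below as a hedged orientation aid; the V2 operator additionally applies l2_shrinkage, which is omitted here.

    import numpy as np

    def ftrl_rows(var, accum, linear, grad, indices, lr, l1, l2, lr_power):
        # FTRL-proximal step applied to each indexed row.
        for g, i in zip(grad, indices):
            new_accum = accum[i] + g * g
            sigma = (new_accum ** -lr_power - accum[i] ** -lr_power) / lr
            linear[i] += g - sigma * var[i]
            quad = new_accum ** -lr_power / lr + 2.0 * l2
            var[i] = np.where(np.abs(linear[i]) > l1,
                              (np.sign(linear[i]) * l1 - linear[i]) / quad, 0.0)
            accum[i] = new_accum
        return var, accum, linear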
+ + +class SparseApplyProximalAdagradNet(nn.Cell): + def __init__(self, var, accum): + super(SparseApplyProximalAdagradNet, self).__init__() + self.sparse_apply_proximal_adagrad = P.SparseApplyProximalAdagrad() + self.var = Parameter(var, name="var") + self.accum = Parameter(accum, name="accum") + + def construct(self, lr, l1, l2, grad, indices): + self.sparse_apply_proximal_adagrad( + self.var, self.accum, lr, l1, l2, grad, indices) + return self.var, self.accum + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_sparse_apply_proximal_adagrad(): + var = Tensor(np.random.rand(3, 3).astype(np.float32)) + accum = Tensor(np.random.rand(3, 3).astype(np.float32)) + net = SparseApplyProximalAdagradNet(var, accum) + + lr = Tensor(0.01, mstype.float32) + l1 = Tensor(0.0, mstype.float32) + l2 = Tensor(0.0, mstype.float32) + grad = Tensor(np.random.rand(3, 3).astype(np.float32)) + indices = Tensor(np.ones((3,), np.int32)) + new_var, new_accum = net(lr, l1, l2, grad, indices) + assert ((new_var != var).any() and (new_accum != accum).any()), \ + "The results should be different!" + + +class SGDNet(nn.Cell): + def __init__(self, var): + super(SGDNet, self).__init__() + self.sgd = P.SGD() + self.var = Parameter(var, name="var") + + def construct(self, gradient, learning_rate, accum, momentum, stat): + self.sgd(self.var, gradient, learning_rate, accum, momentum, stat) + return self.var + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_sgd(): + var = Tensor(np.array([2, -0.5, 1.7, 4]), mstype.float32) + net = SGDNet(var) + + gradient = Tensor(np.array([1, -1, 0.5, 2]), mstype.float32) + learning_rate = Tensor(0.01, mstype.float32) + accum = Tensor(np.array([0.1, 0.3, -0.2, -0.1]), mstype.float32) + momentum = Tensor(0.1, mstype.float32) + stat = Tensor(np.array([1.5, -0.3, 0.2, -0.7]), mstype.float32) + new_var = net(gradient, learning_rate, accum, momentum, stat) + assert (new_var != var).any(), "The results should be different!" 
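Note: for readers tracing test_sgd, the core recurrence is plain momentum SGD. P.SGD also consumes `accum` and a `stat` tensor gating first-step behaviour, which this minimal sketch leaves out on purpose.

    import numpy as np

    def sgd_momentum_step(var, accum, grad, lr, momentum):
        # Plain momentum SGD; the first-step handling via `stat` is omitted.
        accum = momentum * accum + grad
        return var - lr * accum, accum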
+ + +class ApplyProximalAdagradConstantNet(nn.Cell): + def __init__(self, var, accum): + super().__init__() + self.depend = P.Depend() + self.sparse_apply_proximal_adagrad = P.SparseApplyProximalAdagrad() + self.var = Parameter(var, name="var") + self.accum = Parameter(accum, name="accum") + self.const = Tensor(9999, mstype.float32) + + def construct(self, lr, l1, l2, grad, indices): + optimizer = self.sparse_apply_proximal_adagrad( + self.var, self.accum, lr, l1, l2, grad, indices) + return self.depend(self.const, optimizer) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_sparse_apply_proximal_adagrad_constant(): + var = Tensor(np.random.rand(3, 3).astype(np.float32)) + accum = Tensor(np.random.rand(3, 3).astype(np.float32)) + net = ApplyProximalAdagradConstantNet(var, accum) + lr = Tensor(0.01, mstype.float32) + l1 = Tensor(0.1, mstype.float32) + l2 = Tensor(0.2, mstype.float32) + grad = Tensor(np.random.rand(3, 3).astype(np.float32)) + indices = Tensor(np.ones((3,), np.int32)) + net(lr, l1, l2, grad, indices) + assert (net.parameters_dict()['var'].data != var).any() + assert (net.parameters_dict()['accum'].data != accum).any() + + +class MulSGDNet(nn.Cell): + def __init__(self, var): + super().__init__() + self.sgd = P.SGD() + self.var = Parameter(var, name="var") + self.mul = P.Mul() + + def construct(self, gradient, learning_rate, accum, momentum, stat): + out = self.mul(self.var, self.var) + self.sgd(self.var, gradient, learning_rate, accum, momentum, stat) + return out + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_mul_sgd(): + var = Tensor(np.array([2, -0.5, 1.7, 4]), mstype.float32) + net = MulSGDNet(var) + gradient = Tensor(np.array([1, -1, 0.5, 2]), mstype.float32) + learning_rate = Tensor(0.01, mstype.float32) + accum = Tensor(np.array([0.1, 0.3, -0.2, -0.1]), mstype.float32) + momentum = Tensor(0.1, mstype.float32) + stat = Tensor(np.array([1.5, -0.3, 0.2, -0.7]), mstype.float32) + net(gradient, learning_rate, accum, momentum, stat) + assert (net.parameters_dict()['var'].data != var).any() diff --git a/tests/st/auto_monad/test_effect_random.py b/tests/st/auto_monad/test_effect_random.py new file mode 100644 index 0000000000..a5a3aab2cf --- /dev/null +++ b/tests/st/auto_monad/test_effect_random.py @@ -0,0 +1,432 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +import pytest +import numpy as np +import mindspore.nn as nn +import mindspore.ops.operations as P +import mindspore.nn.probability.distribution as msd +from mindspore import context, Tensor +from mindspore.ops import composite as C +from mindspore.common import dtype as mstype +from mindspore import dtype + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + + +class Sampling(nn.Cell): + """ + Test class: sample of Normal distribution. + """ + def __init__(self, shape, seed=0): + super(Sampling, self).__init__() + self.n1 = msd.Normal(0, 1, seed=seed, dtype=dtype.float32) + self.shape = shape + + def construct(self, mean=None, sd=None): + s1 = self.n1.sample(self.shape, mean, sd) + s2 = self.n1.sample(self.shape, mean, sd) + s3 = self.n1.sample(self.shape, mean, sd) + return s1, s2, s3 + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_sample_graph(): + shape = (2, 3) + seed = 0 + samp = Sampling(shape, seed=seed) + sample1, sample2, sample3 = samp() + assert ((sample1 != sample2).any() and (sample1 != sample3).any() and (sample2 != sample3).any()), \ + "The results should be different!" + + +class CompositeNormalNet(nn.Cell): + def __init__(self, shape=None, seed=0): + super(CompositeNormalNet, self).__init__() + self.shape = shape + self.seed = seed + + def construct(self, mean, stddev): + s1 = C.normal(self.shape, mean, stddev, self.seed) + s2 = C.normal(self.shape, mean, stddev, self.seed) + s3 = C.normal(self.shape, mean, stddev, self.seed) + return s1, s2, s3 + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_composite_normal(): + shape = (3, 2, 4) + mean = Tensor(0.0, mstype.float32) + stddev = Tensor(1.0, mstype.float32) + net = CompositeNormalNet(shape) + s1, s2, s3 = net(mean, stddev) + assert ((s1 != s2).any() and (s1 != s3).any() and (s2 != s3).any()), \ + "The results should be different!" + + +class CompositeLaplaceNet(nn.Cell): + def __init__(self, shape=None, seed=0): + super(CompositeLaplaceNet, self).__init__() + self.shape = shape + self.seed = seed + + def construct(self, mean, lambda_param): + s1 = C.laplace(self.shape, mean, lambda_param, self.seed) + s2 = C.laplace(self.shape, mean, lambda_param, self.seed) + s3 = C.laplace(self.shape, mean, lambda_param, self.seed) + return s1, s2, s3 + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_composite_laplace(): + shape = (3, 2, 4) + mean = Tensor(1.0, mstype.float32) + lambda_param = Tensor(1.0, mstype.float32) + net = CompositeLaplaceNet(shape) + s1, s2, s3 = net(mean, lambda_param) + assert ((s1 != s2).any() and (s1 != s3).any() and (s2 != s3).any()), \ + "The results should be different!" 
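Note: every test in this file checks the same property: three textually identical calls to a seeded random op must yield three distinct draws, i.e. the auto-monad must keep the stateful ops ordered rather than let common-subexpression elimination fold them into one. A NumPy analogy of the invariant:

    import numpy as np

    # Three draws from one stateful, seeded generator must differ even
    # though the three call sites look identical.
    rng = np.random.default_rng(0)
    s1 = rng.normal(size=4)
    s2 = rng.normal(size=4)
    s3 = rng.normal(size=4)
    assert (s1 != s2).any() and (s1 != s3).any() and (s2 != s3).any()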
+ + +class CompositeGammaNet(nn.Cell): + def __init__(self, shape=None, seed=0): + super(CompositeGammaNet, self).__init__() + self.shape = shape + self.seed = seed + + def construct(self, alpha, beta): + s1 = C.gamma(self.shape, alpha, beta, self.seed) + s2 = C.gamma(self.shape, alpha, beta, self.seed) + s3 = C.gamma(self.shape, alpha, beta, self.seed) + return s1, s2, s3 + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_composite_gamma(): + shape = (3, 2, 4) + alpha = Tensor(1.0, mstype.float32) + beta = Tensor(1.0, mstype.float32) + net = CompositeGammaNet(shape) + s1, s2, s3 = net(alpha, beta) + assert ((s1 != s2).any() and (s1 != s3).any() and (s2 != s3).any()), \ + "The results should be different!" + + +class CompositePoissonNet(nn.Cell): + def __init__(self, shape=None, seed=0): + super(CompositePoissonNet, self).__init__() + self.shape = shape + self.seed = seed + + def construct(self, mean): + s1 = C.poisson(self.shape, mean, self.seed) + s2 = C.poisson(self.shape, mean, self.seed) + s3 = C.poisson(self.shape, mean, self.seed) + return s1, s2, s3 + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_composite_poisson(): + shape = (3, 2, 4) + mean = Tensor(2.0, mstype.float32) + net = CompositePoissonNet(shape) + s1, s2, s3 = net(mean) + assert ((s1 != s2).any() and (s1 != s3).any() and (s2 != s3).any()), \ + "The results should be different!" + + +class CompositeUniformNet(nn.Cell): + def __init__(self, shape=None, seed=0): + super(CompositeUniformNet, self).__init__() + self.shape = shape + self.seed = seed + + def construct(self, a, b): + s1 = C.uniform(self.shape, a, b, self.seed) + s2 = C.uniform(self.shape, a, b, self.seed) + s3 = C.uniform(self.shape, a, b, self.seed) + return s1, s2, s3 + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_composite_uniform(): + shape = (3, 2, 4) + a = Tensor(0.0, mstype.float32) + b = Tensor(1.0, mstype.float32) + net = CompositeUniformNet(shape) + s1, s2, s3 = net(a, b) + assert ((s1 != s2).any() and (s1 != s3).any() and (s2 != s3).any()), \ + "The results should be different!" + + +class StandardNormalNet(nn.Cell): + def __init__(self, shape, seed=0, seed2=0): + super(StandardNormalNet, self).__init__() + self.shape = shape + self.seed = seed + self.seed2 = seed2 + self.standard_normal = P.StandardNormal(seed, seed2) + + def construct(self): + s1 = self.standard_normal(self.shape) + s2 = self.standard_normal(self.shape) + s3 = self.standard_normal(self.shape) + return s1, s2, s3 + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_standard_normal(): + shape = (4, 16) + net = StandardNormalNet(shape) + s1, s2, s3 = net() + assert ((s1 != s2).any() and (s1 != s3).any() and (s2 != s3).any()), \ + "The results should be different!" 
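Note: the pairwise-difference assertion recurs verbatim in every test here; a shared helper like the hypothetical one below would compress it. This is only a refactoring sketch, not part of the patch.

    def assert_all_different(s1, s2, s3):
        # Shared form of the assertion idiom repeated throughout this file.
        assert ((s1 != s2).any() and (s1 != s3).any() and (s2 != s3).any()), \
            "The results should be different!"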
+ + +class StandardLaplaceNet(nn.Cell): + def __init__(self, shape, seed=0, seed2=0): + super(StandardLaplaceNet, self).__init__() + self.shape = shape + self.seed = seed + self.seed2 = seed2 + self.standard_laplace = P.StandardLaplace(seed, seed2) + + def construct(self): + s1 = self.standard_laplace(self.shape) + s2 = self.standard_laplace(self.shape) + s3 = self.standard_laplace(self.shape) + return s1, s2, s3 + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_standard_laplace(): + shape = (4, 16) + net = StandardLaplaceNet(shape) + s1, s2, s3 = net() + assert ((s1 != s2).any() and (s1 != s3).any() and (s2 != s3).any()), \ + "The results should be different!" + + +class GammaNet(nn.Cell): + def __init__(self, shape, alpha, beta, seed=0, seed2=0): + super(GammaNet, self).__init__() + self.shape = shape + self.alpha = alpha + self.beta = beta + self.seed = seed + self.seed2 = seed2 + self.gamma = P.Gamma(seed, seed2) + + def construct(self): + s1 = self.gamma(self.shape, self.alpha, self.beta) + s2 = self.gamma(self.shape, self.alpha, self.beta) + s3 = self.gamma(self.shape, self.alpha, self.beta) + return s1, s2, s3 + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_gamma(): + shape = (4, 16) + alpha = Tensor(1.0, mstype.float32) + beta = Tensor(1.0, mstype.float32) + net = GammaNet(shape, alpha, beta) + s1, s2, s3 = net() + assert ((s1 != s2).any() and (s1 != s3).any() and (s2 != s3).any()), \ + "The results should be different!" + + +class PoissonNet(nn.Cell): + def __init__(self, shape, seed=0, seed2=0): + super(PoissonNet, self).__init__() + self.shape = shape + self.seed = seed + self.seed2 = seed2 + self.poisson = P.Poisson(seed, seed2) + + def construct(self, mean): + s1 = self.poisson(self.shape, mean) + s2 = self.poisson(self.shape, mean) + s3 = self.poisson(self.shape, mean) + return s1, s2, s3 + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_poisson(): + shape = (4, 16) + mean = Tensor(5.0, mstype.float32) + net = PoissonNet(shape=shape) + s1, s2, s3 = net(mean) + assert ((s1 != s2).any() and (s1 != s3).any() and (s2 != s3).any()), \ + "The results should be different!" + + +class UniformIntNet(nn.Cell): + def __init__(self, shape, seed=0, seed2=0): + super(UniformIntNet, self).__init__() + self.shape = shape + self.seed = seed + self.seed2 = seed2 + self.uniform_int = P.UniformInt(seed, seed2) + + def construct(self, minval, maxval): + s1 = self.uniform_int(self.shape, minval, maxval) + s2 = self.uniform_int(self.shape, minval, maxval) + s3 = self.uniform_int(self.shape, minval, maxval) + return s1, s2, s3 + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_uniform_int(): + shape = (4, 16) + minval = Tensor(1, mstype.int32) + maxval = Tensor(5, mstype.int32) + net = UniformIntNet(shape) + s1, s2, s3 = net(minval, maxval) + assert ((s1 != s2).any() and (s1 != s3).any() and (s2 != s3).any()), \ + "The results should be different!" 
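Note: assuming StandardLaplace draws from Laplace(0, 1) as its name suggests, an inverse-CDF NumPy reference can be used to eyeball the operator's output distribution:

    import numpy as np

    def standard_laplace_ref(shape, rng):
        # Inverse-CDF sampling from Laplace(mean=0, scale=1).
        u = rng.uniform(-0.5, 0.5, size=shape)
        return -np.sign(u) * np.log1p(-2.0 * np.abs(u))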
+ + +class UniformRealNet(nn.Cell): + def __init__(self, shape, seed=0, seed2=0): + super(UniformRealNet, self).__init__() + self.shape = shape + self.seed = seed + self.seed2 = seed2 + self.uniform_real = P.UniformReal(seed, seed2) + + def construct(self): + s1 = self.uniform_real(self.shape) + s2 = self.uniform_real(self.shape) + s3 = self.uniform_real(self.shape) + return s1, s2, s3 + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_uniform_real(): + shape = (4, 16) + net = UniformRealNet(shape) + s1, s2, s3 = net() + assert ((s1 != s2).any() and (s1 != s3).any() and (s2 != s3).any()), \ + "The results should be different!" + + +class DropoutGenMaskNet(nn.Cell): + def __init__(self, shape): + super(DropoutGenMaskNet, self).__init__() + self.shape = shape + self.dropout_gen_mask = P.DropoutGenMask(Seed0=0, Seed1=0) + + def construct(self, keep_prob): + s1 = self.dropout_gen_mask(self.shape, keep_prob) + s2 = self.dropout_gen_mask(self.shape, keep_prob) + s3 = self.dropout_gen_mask(self.shape, keep_prob) + return s1, s2, s3 + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_dropout_gen_mask(): + shape = (2, 4, 5) + keep_prob = Tensor(0.5, mstype.float32) + net = DropoutGenMaskNet(shape) + s1, s2, s3 = net(keep_prob) + assert ((s1 != s2).any() and (s1 != s3).any() and (s2 != s3).any()), \ + "The results should be different!" + + +class RandomChoiceWithMaskNet(nn.Cell): + def __init__(self): + super(RandomChoiceWithMaskNet, self).__init__() + self.rnd_choice_mask = P.RandomChoiceWithMask(count=4, seed=0) + + def construct(self, x): + index1, _ = self.rnd_choice_mask(x) + index2, _ = self.rnd_choice_mask(x) + index3, _ = self.rnd_choice_mask(x) + return index1, index2, index3 + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_random_choice_with_mask(): + net = RandomChoiceWithMaskNet() + x = Tensor(np.array([[1, 0, 1, 0], [0, 0, 0, 1], [1, 1, 1, 1], [0, 0, 0, 1]]).astype(np.bool)) + index1, index2, index3 = net(x) + assert ((index1 != index2).any() and (index1 != index3).any() and (index2 != index3).any()), \ + "The results should be different!" + + +class RandomCategoricalNet(nn.Cell): + def __init__(self, num_sample): + super(RandomCategoricalNet, self).__init__() + self.random_categorical = P.RandomCategorical(mstype.int64) + self.num_sample = num_sample + + def construct(self, logits, seed=0): + s1 = self.random_categorical(logits, self.num_sample, seed) + s2 = self.random_categorical(logits, self.num_sample, seed) + s3 = self.random_categorical(logits, self.num_sample, seed) + return s1, s2, s3 + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_random_categorical(): + num_sample = 8 + net = RandomCategoricalNet(num_sample) + x = Tensor(np.random.random((10, 5)).astype(np.float32)) + # Outputs may be the same, only basic functions are verified here. 
+    net(x)
diff --git a/tests/st/auto_monad/test_float_overflow.py b/tests/st/auto_monad/test_float_overflow.py
new file mode 100644
index 0000000000..374cf3dffa
--- /dev/null
+++ b/tests/st/auto_monad/test_float_overflow.py
@@ -0,0 +1,85 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import pytest
+import numpy as np
+import mindspore.nn as nn
+import mindspore.ops.operations as P
+import mindspore.ops.functional as F
+from mindspore import context, Tensor
+from mindspore.common import dtype as mstype
+
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+
+
+class NpuFloatNet(nn.Cell):
+    """ NpuFloat definition, based on the related code in test_math_ops.py."""
+
+    def __init__(self):
+        super(NpuFloatNet, self).__init__()
+        self.mul = P.Mul()
+        self.alloc_status = P.NPUAllocFloatStatus()
+        self.get_status = P.NPUGetFloatStatus()
+        self.clear_status = P.NPUClearFloatStatus()
+        self.fill = P.Fill()
+        self.shape_op = P.Shape()
+        self.select = P.Select()
+        self.less = P.Less()
+        self.cast = P.Cast()
+        self.dtype = P.DType()
+        self.reduce_sum = P.ReduceSum(keep_dims=True)
+        self.sub = P.Sub()
+        self.neg = P.Neg()
+
+    def construct(self, x):
+        init = self.alloc_status()
+        clear_status = self.clear_status(init)
+        x = F.depend(x, clear_status)  # let x depend on clear_status
+        res = self.sub(x, self.neg(x))
+        init = F.depend(init, res)  # let get_status depend on res
+        get_status = self.get_status(init)
+        # let reduce_sum depend on get_status
+        init = F.depend(init, get_status)
+        flag_sum = self.reduce_sum(init, (0,))
+        base = self.cast(self.fill(self.dtype(
+            res), self.shape_op(res), 0.0), self.dtype(flag_sum))
+        cond = self.less(base, flag_sum)
+        out = self.select(cond, self.cast(base, self.dtype(res)), res)
+        return out
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_float_not_overflow():
+    input_data = Tensor(np.full((8, 5, 3, 1), 655, dtype=np.float16), dtype=mstype.float16)
+    net = NpuFloatNet()
+    out = net(input_data)
+    # not overflowed, so we should get the expected output.
+    expect = Tensor(np.full((8, 5, 3, 1), 655 * 2,
+                            dtype=np.float16), dtype=mstype.float16)
+    np.testing.assert_array_equal(out.asnumpy(), expect.asnumpy())
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_float_overflow():
+    input_data = Tensor(np.full((8, 5, 3, 1), 65504, dtype=np.float16), dtype=mstype.float16)
+    net = NpuFloatNet()
+    out = net(input_data)
+    # all zero if overflowed.
+ assert np.all(out.asnumpy() == 0) diff --git a/tests/st/dynamic_shape/test_ftrl.py b/tests/st/dynamic_shape/test_ftrl.py index 9abecac53e..f65734f57e 100644 --- a/tests/st/dynamic_shape/test_ftrl.py +++ b/tests/st/dynamic_shape/test_ftrl.py @@ -72,3 +72,22 @@ def test_lazy_adam_net(): np.allclose(output.asnumpy(), np.array([[[2, 2]], [[2, 2]], [[2, 2]]])) np.allclose(net.weight1.asnumpy(), np.array([[[0.9, 0.9]], [[0.9, 0.9]], [[1.0, 1.0]]])) np.allclose(net.weight2.asnumpy(), np.array([[[0.9, 0.9]], [[0.9, 0.9]], [[0.9, 0.9]]])) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_lazy_adam_net_sparse(): + indices = Tensor(np.array([0, 0, 1]).astype(np.int32)) + label = Tensor(np.zeros([2, 1, 2]).astype(np.float32)) + net = NetWithSparseGatherV2() + + optimizer = LazyAdam(net.trainable_params(), learning_rate=0.1, weight_decay=0.9, loss_scale=2.0) + # will use sparse_opt in LazyAdam + optimizer.target = 'CPU' + train_network = TrainOneStepCell(net, optimizer) + output = train_network(indices, label) + np.allclose(output.asnumpy(), np.array([[[2, 2]], [[2, 2]], [[2, 2]]])) + np.allclose(net.weight1.asnumpy(), np.array([[[0.9, 0.9]], [[0.9, 0.9]], [[1.0, 1.0]]])) + np.allclose(net.weight2.asnumpy(), np.array([[[0.9, 0.9]], [[0.9, 0.9]], [[0.9, 0.9]]])) diff --git a/tests/st/heterogeneous_excutor/test_control.py b/tests/st/heterogeneous_excutor/test_control.py index 189441f1f9..6ab7a24495 100644 --- a/tests/st/heterogeneous_excutor/test_control.py +++ b/tests/st/heterogeneous_excutor/test_control.py @@ -28,14 +28,14 @@ class Net1(nn.Cell): self.relu1 = P.ReLU() self.relu2 = P.ReLU() self.mul = P.Mul() - self.control = P.ControlDepend() + self.depend = P.Depend() def construct(self, x, y): a = self.relu1(x) + y = self.depend(y, a) b = self.relu2(y) c = self.mul(a, b) - e = self.control(a, b) - return c, e + return c, a class Net2(nn.Cell): @@ -44,14 +44,14 @@ class Net2(nn.Cell): self.relu1 = P.ReLU() self.relu2 = P.ReLU().add_prim_attr("primitive_target", "CPU") self.mul = P.Mul() - self.control = P.ControlDepend() + self.depend = P.Depend() def construct(self, x, y): a = self.relu1(x) + y = self.depend(y, a) b = self.relu2(y) c = self.mul(a, b) - e = self.control(a, b) - return c, e + return c, a def test_net(): diff --git a/tests/st/networks/models/bert/bert_performance/test_bert_tdt_lossscale.py b/tests/st/networks/models/bert/bert_performance/test_bert_tdt_lossscale.py index c6258bbd35..d299c0dc8f 100644 --- a/tests/st/networks/models/bert/bert_performance/test_bert_tdt_lossscale.py +++ b/tests/st/networks/models/bert/bert_performance/test_bert_tdt_lossscale.py @@ -229,7 +229,7 @@ def test_bert_performance(): # assertion occurs while the loss value, overflow state or loss_scale value is wrong loss_value = np.array(callback.loss_list) - expect_loss_value = [11.325791, 11.285011, 11.284766] + expect_loss_value = [11.3660, 11.3265, 11.3264] print("loss value: {}".format(loss_value)) assert np.allclose(loss_value, expect_loss_value, 0, 0.0005) diff --git a/tests/st/networks/models/bert/bert_precision/test_bert_tdt_lossscale.py b/tests/st/networks/models/bert/bert_precision/test_bert_tdt_lossscale.py index 13da17c0b4..53e49b5882 100644 --- a/tests/st/networks/models/bert/bert_precision/test_bert_tdt_lossscale.py +++ b/tests/st/networks/models/bert/bert_precision/test_bert_tdt_lossscale.py @@ -223,14 +223,14 @@ def test_bert_precision(enable_graph_kernel=False): # assertion occurs while the loss 
value, overflow state or loss_scale value is wrong
     loss_value = np.array(callback.loss_list)
-    assert np.allclose(loss_value[0], 12.2065868, 0, 0.000001)
+    assert np.allclose(loss_value[0], 12.2066, 0, 0.0005)
     if enable_graph_kernel:
-        expect_loss_value = [12.2065868, 11.8651543, 11.8282356, 11.8266964, 11.8210478, 12.4073524, 12.0055466,
-                             12.6212320, 12.2229223, 12.4272099]
+        expect_loss_value = [12.206627, 11.840489, 11.798470, 11.796345, 11.790964, 12.366766, 11.971539, 12.576565,
+                             12.185522, 12.386192]
     else:
-        expect_loss_value = [12.2065868, 11.94102, 11.931558, 11.938105, 11.932648, 12.556579, 12.130686, 12.783716,
-                             12.360179, 12.578461]
+        expect_loss_value = [12.206587, 11.966410, 11.965916, 11.975922, 11.970262, 12.608881, 12.174048, 12.840656,
+                             12.407923, 12.631133]
     print("loss value: {}".format(loss_value))
     assert np.allclose(loss_value, expect_loss_value, 0, 0.0005)
diff --git a/tests/st/networks/models/bert/src/bert_for_pre_training.py b/tests/st/networks/models/bert/src/bert_for_pre_training.py
index 57612587bf..0125875fd4 100644
--- a/tests/st/networks/models/bert/src/bert_for_pre_training.py
+++ b/tests/st/networks/models/bert/src/bert_for_pre_training.py
@@ -367,9 +367,8 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
         self.cast = P.Cast()
         self.alloc_status = P.NPUAllocFloatStatus()
         self.get_status = P.NPUGetFloatStatus()
-        self.clear_before_grad = P.NPUClearFloatStatus()
+        self.clear_status = P.NPUClearFloatStatus()
         self.reduce_sum = P.ReduceSum(keep_dims=False)
-        self.depend_parameter_use = P.ControlDepend(depend_mode=1)
         self.base = Tensor(1, mstype.float32)
         self.less_equal = P.LessEqual()
         self.hyper_map = C.HyperMap()
@@ -379,7 +378,6 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
         self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(),
                                            dtype=mstype.float32), name="loss_scale")
-    @C.add_flags(has_effect=True)
     def construct(self,
                   input_ids,
                   input_mask,
@@ -404,7 +402,9 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
             scaling_sens = sens
         # alloc status and clear should be right before gradoperation
         init = self.alloc_status()
-        self.clear_before_grad(init)
+        init = F.depend(init, loss)
+        clear_status = self.clear_status(init)
+        scaling_sens = F.depend(scaling_sens, clear_status)
         grads = self.grad(self.network, weights)(input_ids,
                                                  input_mask,
                                                  token_type_id,
@@ -418,7 +418,9 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell):
             grads = self.grad_reducer(grads)
         grads = self.hyper_map(F.partial(grad_scale, scaling_sens * self.degree), grads)
         grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
-        self.get_status(init)
+        init = F.depend(init, grads)
+        get_status = self.get_status(init)
+        init = F.depend(init, get_status)
         flag_sum = self.reduce_sum(init, (0,))
         if self.is_distributed:
             # sum overflow flag over devices
diff --git a/tests/st/networks/models/bert/src/utils.py b/tests/st/networks/models/bert/src/utils.py
index 9adda84731..f76604ecfc 100644
--- a/tests/st/networks/models/bert/src/utils.py
+++ b/tests/st/networks/models/bert/src/utils.py
@@ -44,7 +44,7 @@ def tensor_grad_scale(scale, grad):
 class BertFinetuneCell(nn.Cell):
     """
-    Especifically defined for finetuning where only four inputs tensor are needed.
+    Specifically defined for finetuning where only four input tensors are needed.
""" def __init__(self, network, optimizer, scale_update_cell=None): @@ -68,9 +68,8 @@ class BertFinetuneCell(nn.Cell): self.cast = P.Cast() self.alloc_status = P.NPUAllocFloatStatus() self.get_status = P.NPUGetFloatStatus() - self.clear_before_grad = P.NPUClearFloatStatus() + self.clear_status = P.NPUClearFloatStatus() self.reduce_sum = P.ReduceSum(keep_dims=False) - self.depend_parameter_use = P.ControlDepend(depend_mode=1) self.base = Tensor(1, mstype.float32) self.less_equal = P.LessEqual() self.hyper_map = C.HyperMap() @@ -98,28 +97,28 @@ class BertFinetuneCell(nn.Cell): scaling_sens = self.loss_scale else: scaling_sens = sens + init = F.depend(init, loss) + clear_status = self.clear_status(init) + scaling_sens = F.depend(scaling_sens, clear_status) grads = self.grad(self.network, weights)(input_ids, input_mask, token_type_id, label_ids, self.cast(scaling_sens, mstype.float32)) - clear_before_grad = self.clear_before_grad(init) - F.control_depend(loss, init) - self.depend_parameter_use(clear_before_grad, scaling_sens) grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads) grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) if self.reducer_flag: grads = self.grad_reducer(grads) - flag = self.get_status(init) + init = F.depend(init, grads) + get_status = self.get_status(init) + init = F.depend(init, get_status) flag_sum = self.reduce_sum(init, (0,)) if self.is_distributed: flag_reduce = self.allreduce(flag_sum) cond = self.less_equal(self.base, flag_reduce) else: cond = self.less_equal(self.base, flag_sum) - F.control_depend(grads, flag) - F.control_depend(flag, flag_sum) overflow = cond if sens is None: overflow = self.loss_scaling_manager(self.loss_scale, cond) @@ -134,7 +133,7 @@ class BertCLSModel(nn.Cell): """ This class is responsible for classification task evaluation, i.e. XNLI(num_labels=3), LCQMC(num_labels=2), Chnsenti(num_labels=2). The returned output represents the final - logits as the results of log_softmax is propotional to that of softmax. + logits as the results of log_softmax is proportional to that of softmax. """ def __init__(self, config, is_training, num_labels=2, dropout_prob=0.0, use_one_hot_embeddings=False): super(BertCLSModel, self).__init__() @@ -162,7 +161,7 @@ class BertCLSModel(nn.Cell): class BertNERModel(nn.Cell): """ This class is responsible for sequence labeling task evaluation, i.e. NER(num_labels=11). - The returned output represents the final logits as the results of log_softmax is propotional to that of softmax. + The returned output represents the final logits as the results of log_softmax is proportional to that of softmax. 
""" def __init__(self, config, is_training, num_labels=11, use_crf=False, dropout_prob=0.0, use_one_hot_embeddings=False): diff --git a/tests/st/ops/cpu/test_lstm_op.py b/tests/st/ops/cpu/test_lstm_op.py index c630dea1bd..52f61dfbbc 100644 --- a/tests/st/ops/cpu/test_lstm_op.py +++ b/tests/st/ops/cpu/test_lstm_op.py @@ -354,7 +354,7 @@ def test_grad(): input_size = 3 hidden_size = 2 num_layers = 1 - has_bias = False + has_bias = True bidirectional = False dropout = 0.0 net = Grad(Net(seq_len, batch_size, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout)) diff --git a/tests/st/ops/graph_kernel/test_fused_adam.py b/tests/st/ops/graph_kernel/test_fused_adam.py index 851d523dad..3939c80576 100644 --- a/tests/st/ops/graph_kernel/test_fused_adam.py +++ b/tests/st/ops/graph_kernel/test_fused_adam.py @@ -36,9 +36,12 @@ class Net(nn.Cell): self.op_cast = P.Cast() self.op_reshape = P.Reshape() self.op_shape = P.Shape() - self.param = Parameter(Tensor(np.array([1, 3, 5]).astype(np.float32)), name='param') - self.m = Parameter(Tensor(np.array([0.11, 0.33, 0.55]).astype(np.float32)), name='m') - self.v = Parameter(Tensor(np.array([1.2, 3.4, 5.6]).astype(np.float32)), name='v') + self.param = Parameter( + Tensor(np.array([1, 3, 5]).astype(np.float32)), name='param') + self.m = Parameter( + Tensor(np.array([0.11, 0.33, 0.55]).astype(np.float32)), name='m') + self.v = Parameter( + Tensor(np.array([1.2, 3.4, 5.6]).astype(np.float32)), name='v') @ms_function def construct(self, beta1, beta2, one_sub_beta_1, one_sub_beta_2, gradient, eps, weight_decay_tensor, lr): @@ -48,14 +51,16 @@ class Net(nn.Cell): gradient_fp32 = self.op_cast(gradient, mstype.float32) next_m = self.op_mul(beta1, m_fp32) + \ - self.op_mul(self.op_cast(one_sub_beta_1, mstype.float32), gradient_fp32) + self.op_mul(self.op_cast(one_sub_beta_1, + mstype.float32), gradient_fp32) next_v = self.op_mul(beta2, v_fp32) + self.op_mul(self.op_cast(one_sub_beta_2, mstype.float32), self.op_square(gradient_fp32)) update = next_m / (eps + self.op_sqrt(next_v)) if self.decay_flag: update = self.op_mul(weight_decay_tensor, param_fp32) + update update_with_lr = self.op_mul(lr, update) - next_param = param_fp32 - self.op_reshape(update_with_lr, self.op_shape(param_fp32)) + next_param = param_fp32 - \ + self.op_reshape(update_with_lr, self.op_shape(param_fp32)) depend_v = F.depend(next_param, F.assign(self.param, next_param)) depend_v = F.depend(depend_v, F.assign(self.m, next_m)) @@ -63,6 +68,54 @@ class Net(nn.Cell): return depend_v +class SideEffectFusedAdamNet(nn.Cell): + def __init__(self, decay_flag=True): + super(SideEffectFusedAdamNet, self).__init__() + self.decay_flag = decay_flag + self.op_mul = P.Mul() + self.op_square = P.Square() + self.op_sqrt = P.Sqrt() + self.op_cast = P.Cast() + self.op_reshape = P.Reshape() + self.op_shape = P.Shape() + self.param = Parameter( + Tensor(np.array([0, 0, 0]).astype(np.float32)), name='param') + self.m = Parameter( + Tensor(np.array([0.11, 0.33, 0.55]).astype(np.float32)), name='m') + self.v = Parameter( + Tensor(np.array([1.2, 3.4, 5.6]).astype(np.float32)), name='v') + self.x = Parameter( + Tensor(np.array([1, 3, 5]).astype(np.float32)), name='x') + + @ms_function + def construct(self, beta1, beta2, one_sub_beta_1, one_sub_beta_2, gradient, eps, weight_decay_tensor, lr): + F.assign(self.param, self.x) + + param_fp32 = self.op_cast(self.param, mstype.float32) + m_fp32 = self.op_cast(self.m, mstype.float32) + v_fp32 = self.op_cast(self.v, mstype.float32) + gradient_fp32 = 
self.op_cast(gradient, mstype.float32) + + next_m = self.op_mul(beta1, m_fp32) + \ + self.op_mul(self.op_cast(one_sub_beta_1, + mstype.float32), gradient_fp32) + next_v = self.op_mul(beta2, v_fp32) + self.op_mul(self.op_cast(one_sub_beta_2, + mstype.float32), self.op_square(gradient_fp32)) + update = next_m / (eps + self.op_sqrt(next_v)) + if self.decay_flag: + update = self.op_mul(weight_decay_tensor, param_fp32) + update + update_with_lr = self.op_mul(lr, update) + next_param = param_fp32 - \ + self.op_reshape(update_with_lr, self.op_shape(param_fp32)) + + depend_v = F.depend(next_param, F.assign(self.param, next_param)) + depend_v = F.depend(depend_v, F.assign(self.m, next_m)) + depend_v = F.depend(depend_v, F.assign(self.v, next_v)) + + F.assign(self.x, self.m) + return depend_v + + def CalFusedAdam(beta1, beta2, one_sub_beta_1, one_sub_beta_2, gradient, eps, weight_decay_tensor, lr, param, m, v, is_weight_decay=False): m_expect = beta1 * m + one_sub_beta_1 * gradient @@ -95,9 +148,12 @@ def test_adam(): param_expect, m_expect, v_expect = CalFusedAdam( beta1, beta2, one_sub_beta_1, one_sub_beta_2, gradient, eps, weight_decay_tensor, lr, param, m, v, is_weight_decay) - assert np.allclose(opt.param.data.asnumpy(), param_expect, rtol=1.e-4, atol=1.e-8, equal_nan=True) - assert np.allclose(opt.m.data.asnumpy(), m_expect, rtol=1.e-4, atol=1.e-8, equal_nan=True) - assert np.allclose(opt.v.data.asnumpy(), v_expect, rtol=1.e-4, atol=1.e-8, equal_nan=True) + assert np.allclose(opt.param.data.asnumpy(), param_expect, + rtol=1.e-4, atol=1.e-8, equal_nan=True) + assert np.allclose(opt.m.data.asnumpy(), m_expect, + rtol=1.e-4, atol=1.e-8, equal_nan=True) + assert np.allclose(opt.v.data.asnumpy(), v_expect, + rtol=1.e-4, atol=1.e-8, equal_nan=True) def test_adam_weight_decay(): @@ -122,21 +178,57 @@ def test_adam_weight_decay(): beta1, beta2, one_sub_beta_1, one_sub_beta_2, gradient, eps, weight_decay_tensor, lr, param, m, v, is_weight_decay) - assert np.allclose(opt.param.data.asnumpy(), param_expect, rtol=1.e-4, atol=1.e-8, equal_nan=True) - assert np.allclose(opt.m.data.asnumpy(), m_expect, rtol=1.e-4, atol=1.e-8, equal_nan=True) - assert np.allclose(opt.v.data.asnumpy(), v_expect, rtol=1.e-4, atol=1.e-8, equal_nan=True) + assert np.allclose(opt.param.data.asnumpy(), param_expect, + rtol=1.e-4, atol=1.e-8, equal_nan=True) + assert np.allclose(opt.m.data.asnumpy(), m_expect, + rtol=1.e-4, atol=1.e-8, equal_nan=True) + assert np.allclose(opt.v.data.asnumpy(), v_expect, + rtol=1.e-4, atol=1.e-8, equal_nan=True) + + +def test_adam_side_effect(): + np.random.seed(0) + beta1 = np.array([0.9]).astype(np.float32) + beta2 = np.array([0.999]).astype(np.float32) + one_sub_beta_1 = (np.array([1.0]) - np.array([0.9])).astype(np.float32) + one_sub_beta_2 = (np.array([1.0]) - np.array([0.999])).astype(np.float32) + lr = np.array([0.012]).astype(np.float32) + eps = np.array([1e-6]).astype(np.float32) + weight_decay_tensor = np.array([0.021]).astype(np.float32) + + gradient = np.array([0.01, 0.03, 0.05]).astype(np.float32) + m = np.array([0.11, 0.33, 0.55]).astype(np.float32) + v = np.array([1.2, 3.4, 5.6]).astype(np.float32) + param = np.array([1, 3, 5]).astype(np.float32) + is_weight_decay = False + opt = SideEffectFusedAdamNet(is_weight_decay) + _ = opt(Tensor(beta1), Tensor(beta2), Tensor(one_sub_beta_1), Tensor(one_sub_beta_2), Tensor(gradient), Tensor(eps), + Tensor(weight_decay_tensor), Tensor(lr)) + param_expect, m_expect, v_expect = CalFusedAdam( + beta1, beta2, one_sub_beta_1, one_sub_beta_2, gradient, 
eps, weight_decay_tensor, lr, + param, m, v, is_weight_decay) + assert np.allclose(opt.param.data.asnumpy(), param_expect, + rtol=1.e-4, atol=1.e-8, equal_nan=True) + assert np.allclose(opt.m.data.asnumpy(), m_expect, + rtol=1.e-4, atol=1.e-8, equal_nan=True) + assert np.allclose(opt.v.data.asnumpy(), v_expect, + rtol=1.e-4, atol=1.e-8, equal_nan=True) + assert np.allclose(opt.x.data.asnumpy(), m_expect, + rtol=1.e-4, atol=1.e-8, equal_nan=True) @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard def test_adam_gpu(): - context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU") + context.set_context(mode=context.GRAPH_MODE, + enable_graph_kernel=True, device_target="GPU") test_adam() def test_adam_ascend(): - context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="Ascend") + context.set_context(mode=context.GRAPH_MODE, + enable_graph_kernel=True, device_target="Ascend") test_adam() @@ -144,10 +236,31 @@ def test_adam_ascend(): @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard def test_adam_weight_decay_gpu(): - context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU") + context.set_context(mode=context.GRAPH_MODE, + enable_graph_kernel=True, device_target="GPU") test_adam_weight_decay() def test_adam_weight_decay_ascend(): - context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="Ascend") + context.set_context(mode=context.GRAPH_MODE, + enable_graph_kernel=True, device_target="Ascend") test_adam_weight_decay() + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_adam_side_effect_gpu(): + context.set_context(mode=context.GRAPH_MODE, + enable_graph_kernel=True, device_target="GPU") + test_adam_side_effect() + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_adam_side_effect_ascend(): + context.set_context(mode=context.GRAPH_MODE, + enable_graph_kernel=True, device_target="Ascend") + test_adam_side_effect() diff --git a/tests/st/ops/graph_kernel/test_lamb.py b/tests/st/ops/graph_kernel/test_lamb.py deleted file mode 100644 index be2f0620c1..0000000000 --- a/tests/st/ops/graph_kernel/test_lamb.py +++ /dev/null @@ -1,139 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================ - -import numpy as np -import mindspore.context as context -from mindspore import Tensor, Parameter -from mindspore.nn import Cell -from mindspore.nn._graph_kernels import LambUpdateWithLR, LambNextMV - - -class LambNet(Cell): - def __init__(self, i2, i5, x6): - super(LambNet, self).__init__() - self.i2 = Parameter(i2, name='i2') - self.i5 = Parameter(i5, name='i5') - self.x6 = Parameter(x6, name='x6') - self.lamb_next = LambNextMV() - self.lamb_update = LambUpdateWithLR() - - def construct(self, i1, i3, i4, i6, i7, i8, i9, ix0, ix1, ix2, ix3, - x1, x2, x3, x4, x5, gy, se, my): - i1_ = i1 + i3 - return self.lamb_next(i1_, self.i2, i3, i4, self.i5, i6, i7, i8, i9, ix0, - ix1, ix2, ix3), \ - self.lamb_update(x1, x2, x3, x4, x5, self.x6, gy, se, my) - - -def LambUpdateNumpy(x1, x2, x3, x4, x5, x6, gy, se, my): - trust_ratio = np.where(np.greater(x2, gy), - np.where(np.greater(x1, gy), np.divide(x2, x3), se), - se) - trust_ratio = np.maximum(np.minimum(trust_ratio, my), gy) - update_with_lr = trust_ratio * x4 * x5 - next_param = x6 - np.reshape(update_with_lr, x6.shape) - return next_param - - -def LambNextMVNumpy(i1, i2, i3, i4, i5, i6, i7, i8, i9, x0, x1, x2, x3): - m_fp32 = i5.astype(np.float32) - v_fp32 = i2.astype(np.float32) - next_m = i8 * m_fp32 + i9 * i4 - next_v = x0 * v_fp32 + x1 * i1 - next_mm = next_m / i6 - next_vv = next_v / i3 - update = next_mm / (np.sqrt(next_vv) + x3) - add3 = next_mm / np.sqrt(next_vv + x3) + x2 * i7 - return add3, next_m, next_v, update - - -def tensor_all(*args): - res = [Tensor(a) for a in args] - return res - - -def test_graph_kernel_lamb(): - shape = [1, 16] - oshape = [1] - np.random.seed(0) - x1 = np.random.normal(0, 1, oshape).astype(np.float32) - x2 = np.random.normal(0, 1, oshape).astype(np.float32) - x3 = np.random.normal(0, 1, oshape).astype(np.float32) - x4 = np.random.normal(0, 1, oshape).astype(np.float32) - x5 = np.random.normal(0, 1, shape).astype(np.float32) - x6 = np.random.normal(0, 1, shape).astype(np.float32) - gy = np.random.normal(0, 1, oshape).astype(np.float32) - se = np.random.normal(0, 1, oshape).astype(np.float32) - my = np.random.normal(0, 1, oshape).astype(np.float32) - - tx1, tx2, tx3, tx4, tx5, tx6, tgy, tse, tmy = tensor_all( - x1, x2, x3, x4, x5, x6, gy, se, my) - - np.random.seed(1) - i1 = np.abs(np.random.normal(0, 1, shape)).astype(np.float32) - i2 = np.abs(np.random.normal(0, 1, shape)).astype(np.float32) - i3 = np.abs(np.random.normal(0, 1, shape)).astype(np.float32) - i4 = np.random.normal(0, 1, shape).astype(np.float32) - i5 = np.random.normal(0, 1, shape).astype(np.float32) - i6 = np.abs(np.random.normal(0, 1, shape)).astype(np.float32) - i7 = np.random.normal(0, 1, shape).astype(np.float32) - i8 = np.random.normal(0, 1, shape).astype(np.float32) - i9 = np.random.normal(0, 1, shape).astype(np.float32) - ix0 = np.abs(np.random.normal(0, 1, shape)).astype(np.float32) - ix1 = np.abs(np.random.normal(0, 1, shape)).astype(np.float32) - ix2 = np.random.normal(0, 1, shape).astype(np.float32) - ix3 = np.ones(shape).astype(np.float32) * 1e-6 - - ti1, ti2, ti3, ti4, ti5, ti6, ti7, ti8, ti9, tix0, tix1, tix2, tix3 = \ - tensor_all(i1, i2, i3, i4, i5, i6, i7, i8, i9, ix0, ix1, ix2, ix3) - - context.set_context(enable_graph_kernel=True) - - net = LambNet(ti2, ti5, tx6) - (wa3, wup), _ = net(ti1, ti3, ti4, ti6, ti7, ti8, ti9, tix0, tix1, tix2, tix3, - tx1, tx2, tx3, tx4, tx5, tgy, tse, tmy) - - wi2 = net.i2.data.asnumpy().copy() - wi5 = 
net.i5.data.asnumpy().copy() - ares = net.x6.data.asnumpy().copy() - - context.set_context(enable_graph_kernel=False) - - i1_ = i1 + i3 - a3, a0, a1, up = LambNextMVNumpy(i1_, i2, i3, i4, i5, i6, i7, i8, i9, ix0, - ix1, ix2, ix3) - - np_res = LambUpdateNumpy(x1, x2, x3, x4, x5, x6, gy, se, my) - - rtol = 0.0001 - atol = 0.0001 - - wres = (wa3.asnumpy().copy(), wi5, wi2, wup.asnumpy().copy()) - bres = (a3, a0, a1, up) - - cmp_res = list(map(lambda x, y: np.allclose(x, y, rtol, atol), - wres, bres)) - - assert all(cmp_res) and np.allclose(ares, np_res, rtol, atol) - - -def test_graph_kernel_lamb_gpu(): - context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU") - test_graph_kernel_lamb() - - -def test_graph_kernel_lamb_ascend(): - context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="Ascend") - test_graph_kernel_lamb() diff --git a/tests/ut/cpp/CMakeLists.txt b/tests/ut/cpp/CMakeLists.txt index f10a4a3f75..1996dd955c 100644 --- a/tests/ut/cpp/CMakeLists.txt +++ b/tests/ut/cpp/CMakeLists.txt @@ -121,6 +121,7 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "../../../mindspore/ccsrc/backend/optimizer/graph_kernel/*.cc" "../../../mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc" "../../../mindspore/ccsrc/backend/session/ascend_session.cc" + "../../../mindspore/ccsrc/backend/session/ascend_auto_monad.cc" "../../../mindspore/ccsrc/backend/session/ascend_control_parser.cc" "../../../mindspore/ccsrc/backend/session/kernel_graph.cc" "../../../mindspore/ccsrc/backend/session/session_basic.cc" @@ -145,6 +146,7 @@ list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/ps_cache/gpu/gp list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/ps_cache/ascend/ascend_ps_cache.cc") list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc") list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_fusion.cc") +list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/optimizer/gpu/post_batch_norm_add_relu_fusion.cc") list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/optimizer/gpu/batch_norm_add_relu_grad_fusion.cc") list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/optimizer/gpu/batch_norm_relu_fusion.cc") list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/backend/optimizer/gpu/batch_norm_relu_grad_fusion.cc") diff --git a/tests/ut/cpp/pre_activate/pass/allreduce_fusion_test.cc b/tests/ut/cpp/pre_activate/pass/allreduce_fusion_test.cc index 7178ee6679..e42973e694 100644 --- a/tests/ut/cpp/pre_activate/pass/allreduce_fusion_test.cc +++ b/tests/ut/cpp/pre_activate/pass/allreduce_fusion_test.cc @@ -214,7 +214,7 @@ TEST_F(TestHWAllReduceFusion, test_fusion_sorted) { FuncGraphPtr new_graph = optimizer->Optimize(func_graph); EXPECT_NE(new_graph, nullptr); // check result - FuncGraphPtr g_after = getPyFun_.CallAndParseRet("test_all_reduce_fusion_all", "after1"); + FuncGraphPtr g_after = getPyFun_.CallAndParseRet("test_all_reduce_fusion_all", "after"); EXPECT_NE(g_after, nullptr); EXPECT_TRUE(CheckEqualGraph(new_graph, g_after)); } diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/fused_batch_norm_fusion_test.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/fused_batch_norm_fusion_test.py index 8223dbf59e..79ec70ca69 100644 --- a/tests/ut/cpp/python_input/gtest_input/pre_activate/fused_batch_norm_fusion_test.py +++ 
b/tests/ut/cpp/python_input/gtest_input/pre_activate/fused_batch_norm_fusion_test.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================ import mindspore.common.dtype as mstype +from mindspore.common import monad from mindspore.common.tensor import Tensor from mindspore.ops import Primitive from mindspore.ops import operations as P @@ -24,6 +25,8 @@ Mul = P.Mul() Sub = P.Sub() make_tuple = Primitive('make_tuple') tuple_getitem = Primitive(Constants.kTupleGetItem) +update_state = Primitive('UpdateState') +U = monad.U BatchNorm = P.BatchNorm() Cast = P.Cast() BNTrainingReduce = Primitive('BNTrainingReduce') @@ -53,11 +56,13 @@ def test_fused_batch_norm_fusion(tag): sub1 = Sub(var1, tuple_getitem(batch_norm, 2)) mul0 = Mul(sub0, constant0) mul1 = Mul(sub1, constant1) - assign_sub0 = AssignSub(var0, mul0) - assign_sub1 = AssignSub(var1, mul1) + assign_sub0 = AssignSub(var0, mul0, U) + u0 = update_state(U, assign_sub0) + assign_sub1 = AssignSub(var1, mul1, u0) + u1 = update_state(u0, assign_sub1) depend0 = F.depend(tuple_getitem(batch_norm, 0), assign_sub0) depend1 = F.depend(depend0, assign_sub1) - outputs = make_tuple(depend1, tuple_getitem(batch_norm, 3), tuple_getitem(batch_norm, 4)) + outputs = make_tuple(depend1, tuple_getitem(batch_norm, 3), tuple_getitem(batch_norm, 4), u1) output = tuple_getitem(outputs, 0) return output @@ -68,11 +73,13 @@ def test_fused_batch_norm_fusion(tag): sub1 = Sub(Cast(var1, mstype.float32), tuple_getitem(batch_norm, 2)) mul0 = Mul(sub0, constant0) mul1 = Mul(sub1, constant1) - assign_sub0 = AssignSub(var0, Cast(mul0, mstype.float32)) - assign_sub1 = AssignSub(var1, Cast(mul1, mstype.float32)) + assign_sub0 = AssignSub(var0, Cast(mul0, mstype.float32), U) + u0 = update_state(U, assign_sub0) + assign_sub1 = AssignSub(var1, Cast(mul1, mstype.float32), u0) + u1 = update_state(u0, assign_sub1) depend0 = F.depend(tuple_getitem(batch_norm, 0), assign_sub0) depend1 = F.depend(depend0, assign_sub1) - outputs = make_tuple(depend1, tuple_getitem(batch_norm, 3), tuple_getitem(batch_norm, 4)) + outputs = make_tuple(depend1, tuple_getitem(batch_norm, 3), tuple_getitem(batch_norm, 4), u1) output = tuple_getitem(outputs, 0) return output @@ -83,11 +90,13 @@ def test_fused_batch_norm_fusion(tag): sub1 = Sub(Cast(var1, mstype.float32), tuple_getitem(batch_norm, 2)) mul0 = Mul(Cast(sub0, mstype.float32), constant0) mul1 = Mul(Cast(sub1, mstype.float32), constant1) - assign_sub0 = AssignSub(var0, mul0) - assign_sub1 = AssignSub(var1, mul1) + assign_sub0 = AssignSub(var0, mul0, U) + u0 = update_state(U, assign_sub0) + assign_sub1 = AssignSub(var1, mul1, u0) + u1 = update_state(u0, assign_sub1) depend0 = F.depend(tuple_getitem(batch_norm, 0), assign_sub0) depend1 = F.depend(depend0, assign_sub1) - outputs = make_tuple(depend1, tuple_getitem(batch_norm, 3), tuple_getitem(batch_norm, 4)) + outputs = make_tuple(depend1, tuple_getitem(batch_norm, 3), tuple_getitem(batch_norm, 4), u1) output = tuple_getitem(outputs, 0) return output @@ -97,7 +106,7 @@ def test_fused_batch_norm_fusion(tag): bn_training_update = BNTrainingUpdate(input0, tuple_getitem(bn_training_reduce, 0), tuple_getitem(bn_training_reduce, 1), input1, input2, var0, var1) outputs = make_tuple(tuple_getitem(bn_training_update, 0), tuple_getitem(bn_training_update, 3), - tuple_getitem(bn_training_update, 4)) + tuple_getitem(bn_training_update, 4), U) output = tuple_getitem(outputs, 0) return make_tuple(output) diff --git 
a/tests/ut/cpp/python_input/gtest_input/pre_activate/insert_memcpy_async_for_hccl_op.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/insert_memcpy_async_for_hccl_op.py index 9f7abbaa64..34303dad93 100644 --- a/tests/ut/cpp/python_input/gtest_input/pre_activate/insert_memcpy_async_for_hccl_op.py +++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/insert_memcpy_async_for_hccl_op.py @@ -17,13 +17,14 @@ from mindspore.ops import Primitive from mindspore.ops import operations as P from mindspore.ops import _constants as Constants +depend = P.Depend() all_reduce = P.AllReduce() broadcast = P.Broadcast(1) memcpy_async = Primitive('memcpy_async') make_tuple = Primitive('make_tuple') tuple_getitem = Primitive(Constants.kTupleGetItem) assign_add = P.AssignAdd() -control_depend = P.ControlDepend() +apply_momentun = P.ApplyMomentum() relu = P.ReLU() @@ -106,7 +107,7 @@ def test_insert_memcpy_async_for_hccl_op_cond4(tag): def before(a, b): x = relu(a) y = all_reduce(b) - res = control_depend(x, y) + res = depend(x, y) return res @fns @@ -114,7 +115,7 @@ def test_insert_memcpy_async_for_hccl_op_cond4(tag): x = relu(a) y1 = memcpy_async(b) y2 = all_reduce(y1) - res = control_depend(x, make_tuple(y1, y2)) + res = depend(x, y2) return make_tuple(res) return fns[tag] @@ -127,7 +128,7 @@ def test_insert_memcpy_async_for_hccl_op_cond5(tag): def before(a, b, c): x = relu(a) y = broadcast((b, c)) - res = control_depend(x, y) + res = depend(x, y) return res @fns @@ -136,7 +137,9 @@ def test_insert_memcpy_async_for_hccl_op_cond5(tag): m1 = memcpy_async(b) m2 = memcpy_async(c) y = broadcast(m1, m2) - res = control_depend(x, make_tuple(m1, m2, y)) + y0 = tuple_getitem(y, 0) + y1 = tuple_getitem(y, 1) + res = depend(x, make_tuple(y0, y1)) return make_tuple(res) return fns[tag] diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/ir_fusion_test.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/ir_fusion_test.py index bec8c08354..fe323432c4 100644 --- a/tests/ut/cpp/python_input/gtest_input/pre_activate/ir_fusion_test.py +++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/ir_fusion_test.py @@ -132,17 +132,6 @@ def test_all_reduce_fusion_all(tag): @fns def after(x1, x2, x3, x4, x5): - ar = allreduce(x5, x4, x3, x2, x1) - y5 = tuple_getitem(ar, 0) - y4 = tuple_getitem(ar, 1) - y3 = tuple_getitem(ar, 2) - y2 = tuple_getitem(ar, 3) - y1 = tuple_getitem(ar, 4) - res = make_tuple(y1, y2, y3, y4, y5) - return make_tuple(res) - - @fns - def after1(x1, x2, x3, x4, x5): ar = allreduce(x1, x2, x3, x4, x5) y1 = tuple_getitem(ar, 0) y2 = tuple_getitem(ar, 1) @@ -170,13 +159,13 @@ def test_all_reduce_fusion_group(tag): @fns def after1(x1, x2, x3, x4, x5): - ar1 = allreduce(x5, x4) - ar2 = allreduce(x3, x2, x1) - y4 = tuple_getitem(ar1, 1) - y5 = tuple_getitem(ar1, 0) - y1 = tuple_getitem(ar2, 2) - y2 = tuple_getitem(ar2, 1) + ar1 = allreduce(x1, x2) + ar2 = allreduce(x3, x4, x5) + y1 = tuple_getitem(ar1, 0) + y2 = tuple_getitem(ar1, 1) y3 = tuple_getitem(ar2, 0) + y4 = tuple_getitem(ar2, 1) + y5 = tuple_getitem(ar2, 2) res = make_tuple(y1, y2, y3, y4, y5) return make_tuple(res) @@ -184,11 +173,11 @@ def test_all_reduce_fusion_group(tag): def after2(x1, x2, x3, x4, x5): ar1 = allreduce(x1, x3, x5) ar2 = allreduce(x2, x4) - y1 = tuple_getitem(ar1, 2) + y1 = tuple_getitem(ar1, 0) y3 = tuple_getitem(ar1, 1) - y5 = tuple_getitem(ar1, 0) - y2 = tuple_getitem(ar2, 1) - y4 = tuple_getitem(ar2, 0) + y5 = tuple_getitem(ar1, 2) + y2 = tuple_getitem(ar2, 0) + y4 = tuple_getitem(ar2, 1) output = make_tuple(y1, 
y2, y3, y4, y5) return make_tuple(output) 
diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/momentum_lossscale_fusion_test.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/momentum_lossscale_fusion_test.py index 160e7da73d..1eda961918 100644 --- a/tests/ut/cpp/python_input/gtest_input/pre_activate/momentum_lossscale_fusion_test.py +++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/momentum_lossscale_fusion_test.py @@ -13,10 +13,12 @@ # limitations under the License. # ============================================================================ import mindspore.common.dtype as mstype +from mindspore.common import monad from mindspore.common.tensor import Tensor from mindspore.ops import Primitive from mindspore.ops import operations as P from mindspore.ops import _constants as Constants +from mindspore.ops import functional as F Mul = P.Mul() ApplyMomentum = P.ApplyMomentum() @@ -43,11 +45,12 @@ def test_momentum_lossscale_fusion(tag): @fns def before(input0, input1, input2, input3, input4): mul = Mul(constant, input3) - fused_mul_apply_momentum = ApplyMomentum(input0, input1, input2, mul, input4) + fused_mul_apply_momentum = ApplyMomentum(input0, input1, input2, mul, input4, monad.U) return fused_mul_apply_momentum @fns def after(input0, input1, input2, input3, input4): - return make_tuple(tuple_getitem(FusedMulApplyMomentum(input0, input1, input2, input3, input4, constant), 0)) + dep = F.depend(input4, monad.U) + return make_tuple(tuple_getitem(FusedMulApplyMomentum(input0, input1, input2, input3, dep, constant), 0)) return fns[tag] 
diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/optimize_dependence_test.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/optimize_dependence_test.py index 2acaff2987..23ece9e67d 100644 --- a/tests/ut/cpp/python_input/gtest_input/pre_activate/optimize_dependence_test.py +++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/optimize_dependence_test.py @@ -16,7 +16,6 @@ from mindspore.ops import Primitive from mindspore.ops import operations as P depend = P.Depend() -controldepend = Primitive("ControlDepend") TransData = Primitive('TransData') add = P.Add() make_tuple = Primitive('make_tuple') @@ -78,13 +77,13 @@ def test_optimize_control_dependence(tag): @fns def before(x, y, z): new_z = TransData(z) - depend_intput = controldepend(y, new_z) + depend_intput = depend(y, new_z) sum_add = add(x, depend_intput) return sum_add @fns def after(x, y, z): - depend_intput = controldepend(y, z) + depend_intput = depend(y, z) sum_add = add(x, depend_intput) return sum_add @@ -97,14 +96,14 @@ def test_optimize_control_dependence_with_make_tuple(tag): @fns def before(x, y, a, b): z = make_tuple(TransData(a), TransData(b)) - depend_intput = controldepend(y, z) + depend_intput = depend(y, z) sum_add = add(x, depend_intput) return sum_add @fns def after(x, y, a, b): z = make_tuple(a, b) - depend_intput = controldepend(y, z) + depend_intput = depend(y, z) sum_add = add(x, depend_intput) return sum_add 
diff --git a/tests/ut/cpp/session/anf_runtime_algorithm_test.cc b/tests/ut/cpp/session/anf_runtime_algorithm_test.cc index ab74aefd0e..d13c370a03 100644 --- a/tests/ut/cpp/session/anf_runtime_algorithm_test.cc +++ b/tests/ut/cpp/session/anf_runtime_algorithm_test.cc @@ -86,19 +86,13 @@ TEST_F(AnfRuntimeAlgorithmTest, VisitKernel) { EXPECT_NE(kernel_with_index.first->cast<CNodePtr>(), nullptr); EXPECT_EQ((kernel_with_index.first->cast<CNodePtr>()).get(), add_second.get()); EXPECT_EQ(kernel_with_index.second, 0); - // test depend or control depend
node as input + // test depend node as input std::vector<AnfNodePtr> depend_inputs{NewValueNode(prim::kPrimDepend), add, add_second}; auto depend = kernel_graph->NewCNode(depend_inputs); kernel_with_index = AnfAlgo::VisitKernel(depend, 0); EXPECT_NE(kernel_with_index.first->cast<CNodePtr>(), nullptr); EXPECT_EQ((kernel_with_index.first->cast<CNodePtr>()).get(), add.get()); EXPECT_EQ(kernel_with_index.second, 0); - std::vector<AnfNodePtr> control_depend_inputs{NewValueNode(prim::kPrimControlDepend), add_second, add}; - auto control_depend = kernel_graph->NewCNode(control_depend_inputs); - kernel_with_index = AnfAlgo::VisitKernel(control_depend, 0); - EXPECT_NE(kernel_with_index.first->cast<CNodePtr>(), nullptr); - EXPECT_EQ((kernel_with_index.first->cast<CNodePtr>()).get(), add_second.get()); - EXPECT_EQ(kernel_with_index.second, 0); } TEST_F(AnfRuntimeAlgorithmTest, GetCNodePrimitive) { 
diff --git a/tests/ut/python/ir/test_row_tensor.py b/tests/ut/python/ir/test_row_tensor.py index 8d21d8550f..b2d2b2f8b8 100644 --- a/tests/ut/python/ir/test_row_tensor.py +++ b/tests/ut/python/ir/test_row_tensor.py @@ -40,7 +40,11 @@ from mindspore.nn.optim import Momentum from mindspore.train import Model from ....dataset_mock import MindData -context.set_context(mode=context.GRAPH_MODE, enable_sparse=True) +@pytest.fixture(scope="module", autouse=True) +def setup_teardown(): + context.set_context(mode=context.GRAPH_MODE, enable_sparse=True) + yield + context.set_context(enable_sparse=False) reduce_sum = P.ReduceSum() unsorted_segment_sum = P.UnsortedSegmentSum() 
diff --git a/tests/ut/python/ir/test_sparse_tensor.py b/tests/ut/python/ir/test_sparse_tensor.py index 184bc26d93..bd84640ec2 100644 --- a/tests/ut/python/ir/test_sparse_tensor.py +++ b/tests/ut/python/ir/test_sparse_tensor.py @@ -26,7 +26,12 @@ import mindspore.nn as nn from mindspore.ops import composite as C from mindspore import Tensor, SparseTensor, context -context.set_context(mode=context.GRAPH_MODE, enable_sparse=True) +@pytest.fixture(scope="module", autouse=True) +def setup_teardown(): + context.set_context(mode=context.GRAPH_MODE, enable_sparse=True) + yield + context.set_context(enable_sparse=False) + grad_op = C.GradOperation(get_all=True) 
diff --git a/tests/ut/python/keep_order/test_keep_order.py b/tests/ut/python/keep_order/test_keep_order.py index b2a6d4d144..044d6e0bd7 100644 --- a/tests/ut/python/keep_order/test_keep_order.py +++ b/tests/ut/python/keep_order/test_keep_order.py @@ -43,9 +43,14 @@ class Func(nn.Cell): init = self.alloc_status() sum_ = add(x, y) product = mul1(x, y) - flag = self.get_status(init) + init = F.depend(init, sum_) + init = F.depend(init, product) + get_status = self.get_status(init) + sum_ = F.depend(sum_, get_status) + product = F.depend(product, get_status) out = add2(sum_, product) - clear = self.clear_status(flag) + init = F.depend(init, out) + clear = self.clear_status(init) out = F.depend(out, clear) return out @@ -65,11 +70,16 @@ class Net(nn.Cell): init = self.alloc_status() sum1 = add(x, y) dx = grad_s(self.func)(x, y, sens) - flag = self.get_status(init) + init = F.depend(init, sum1) + init = F.depend(init, dx) + get_status = self.get_status(init) + sum1 = F.depend(sum1, get_status) + dx = F.depend(dx, get_status) sum2 = add2(sum1, dx[0]) sum3 = add2(y, dx[1]) out = add2(sum2, sum3) - clear = self.clear_status(flag) + init = F.depend(init, out) + clear = self.clear_status(init) out = F.depend(out, clear) return out @@ -78,7 +88,6 @@ def test_add(): x = Tensor(np.ones([3, 3]).astype(np.float32)) y = Tensor(np.ones([3, 3]).astype(np.float32)) func = Func() -
func.add_flags(has_effect=True) func(x, y) @@ -87,7 +96,6 @@ def test_sens(): y = Tensor(np.ones([3, 3]).astype(np.float32)) sens = Tensor(np.ones([3, 3]).astype(np.float32)) net = Net() - net.add_flags(has_effect=True) _ = net(x, y, sens) @@ -104,11 +112,16 @@ class Net_hyper(nn.Cell): add1 = add(x, y) sum1 = C.hyper_add([add1, add1], [x, y]) dx = grad_s(self.func)(x, y, sens) - flag = self.get_status(init) + init = F.depend(init, sum1) + init = F.depend(init, dx) + get_status = self.get_status(init) + sum1 = F.depend(sum1, get_status) + dx = F.depend(dx, get_status) sum2 = add2(sum1[0], dx[0]) sum3 = add2(sum1[1], dx[1]) out = C.hyper_add([sum2, sum2], [sum3, sum3]) - clear = self.clear_status(flag) + init = F.depend(init, out) + clear = self.clear_status(init) out = F.depend(out, clear) return out @@ -118,7 +131,6 @@ def test_hyper_add(): y = Tensor(np.ones([3, 3]).astype(np.float32)) sens = Tensor(np.ones([3, 3]).astype(np.float32)) net = Net_hyper() - net.add_flags(has_effect=True) _ = net(x, y, sens) @@ -134,12 +146,14 @@ def test_keep_order_io_effect_exception_return_dtype(): self.sub = P.Sub() self.neg = P.Neg() - @C.add_flags(has_effect=True) def construct(self, x): init = self.alloc_status() - self.clear_status(init) + clear_status = self.clear_status(init) + x = F.depend(x, clear_status) res = self.sub(x, self.neg(x)) - self.get_status(init) + init = F.depend(init, res) + get_status = self.get_status(init) + res = F.depend(res, get_status) dtype = self.dtype(res) return dtype diff --git a/tests/ut/python/nn/optim/test_ada_grad.py b/tests/ut/python/nn/optim/test_ada_grad.py index d2fe45773e..fb27a294e5 100644 --- a/tests/ut/python/nn/optim/test_ada_grad.py +++ b/tests/ut/python/nn/optim/test_ada_grad.py @@ -14,6 +14,7 @@ # ============================================================================ """ test ADA_GRAD """ +import pytest import numpy as np import mindspore.nn as nn @@ -23,7 +24,12 @@ from mindspore.nn import TrainOneStepCell, WithLossCell from mindspore.nn.optim import Adagrad from mindspore.ops import operations as P -context.set_context(enable_sparse=True) + +@pytest.fixture(scope="module", autouse=True) +def setup_teardown(): + context.set_context(enable_sparse=True) + yield + context.set_context(enable_sparse=False) class Net(nn.Cell): diff --git a/tests/ut/python/nn/optim/test_adam.py b/tests/ut/python/nn/optim/test_adam.py index 774cede36d..43b459b8b4 100644 --- a/tests/ut/python/nn/optim/test_adam.py +++ b/tests/ut/python/nn/optim/test_adam.py @@ -23,7 +23,11 @@ from mindspore.nn import TrainOneStepCell, WithLossCell from mindspore.nn.optim import Adam, AdamWeightDecay from mindspore.ops import operations as P -context.set_context(enable_sparse=True) +@pytest.fixture(scope="module", autouse=True) +def setup_teardown(): + context.set_context(enable_sparse=True) + yield + context.set_context(enable_sparse=False) class Net(nn.Cell): """ Net definition """ diff --git a/tests/ut/python/nn/optim/test_ftrl.py b/tests/ut/python/nn/optim/test_ftrl.py index 2b5f1ef481..a2d8c2efb3 100644 --- a/tests/ut/python/nn/optim/test_ftrl.py +++ b/tests/ut/python/nn/optim/test_ftrl.py @@ -13,7 +13,7 @@ # limitations under the License. 
# ============================================================================ """ test FTRL """ - +import pytest import numpy as np import mindspore.nn as nn @@ -23,7 +23,12 @@ from mindspore.nn import TrainOneStepCell, WithLossCell from mindspore.nn.optim import FTRL from mindspore.ops import operations as P -context.set_context(enable_sparse=True) +@pytest.fixture(scope="module", autouse=True) +def setup_teardown(): + context.set_context(enable_sparse=True) + yield + context.set_context(enable_sparse=False) + class Net(nn.Cell): def __init__(self): diff --git a/tests/ut/python/nn/optim/test_lazyadam.py b/tests/ut/python/nn/optim/test_lazyadam.py index f97b22f9da..2d52003d28 100644 --- a/tests/ut/python/nn/optim/test_lazyadam.py +++ b/tests/ut/python/nn/optim/test_lazyadam.py @@ -23,7 +23,12 @@ from mindspore.nn import TrainOneStepCell, WithLossCell from mindspore.nn.optim import LazyAdam from mindspore.ops import operations as P -context.set_context(enable_sparse=True) +@pytest.fixture(scope="module", autouse=True) +def setup_teardown(): + context.set_context(enable_sparse=True) + yield + context.set_context(enable_sparse=False) + class Net(nn.Cell): """ Net definition """ diff --git a/tests/ut/python/nn/optim/test_proximal_ada_grad.py b/tests/ut/python/nn/optim/test_proximal_ada_grad.py index d8a8198072..674cddae25 100644 --- a/tests/ut/python/nn/optim/test_proximal_ada_grad.py +++ b/tests/ut/python/nn/optim/test_proximal_ada_grad.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================ """ test PROXIMAL_ADA_GRAD """ - +import pytest import numpy as np import mindspore.nn as nn @@ -23,7 +23,12 @@ from mindspore.nn import TrainOneStepCell, WithLossCell from mindspore.nn.optim import ProximalAdagrad from mindspore.ops import operations as P -context.set_context(enable_sparse=True) +@pytest.fixture(scope="module", autouse=True) +def setup_teardown(): + context.set_context(enable_sparse=True) + yield + context.set_context(enable_sparse=False) + class Net(nn.Cell): def __init__(self): diff --git a/tests/ut/python/nn/optim/test_target.py b/tests/ut/python/nn/optim/test_target.py index f518fedc6d..8bd51d440f 100644 --- a/tests/ut/python/nn/optim/test_target.py +++ b/tests/ut/python/nn/optim/test_target.py @@ -13,13 +13,18 @@ # limitations under the License. 
# ============================================================================ """ test lazy adam """ +import pytest import numpy as np from mindspore.nn.optim import LazyAdam, FTRL, Adam, ProximalAdagrad import mindspore.nn as nn from mindspore import Tensor, Parameter, context from mindspore.ops import operations as P -context.set_context(enable_sparse=True) +@pytest.fixture(scope="module", autouse=True) +def setup_teardown(): + context.set_context(enable_sparse=True) + yield + context.set_context(enable_sparse=False) class NetWithSparseGatherV2(nn.Cell): diff --git a/tests/ut/python/nn/test_nn_embedding.py b/tests/ut/python/nn/test_nn_embedding.py index 5eb4964840..8a8dcef841 100755 --- a/tests/ut/python/nn/test_nn_embedding.py +++ b/tests/ut/python/nn/test_nn_embedding.py @@ -82,14 +82,14 @@ def test_check_multifield_embedding_false_type_field_id(): @non_graph_engine def test_check_multifield_embedding_false_input_shape(): - with pytest.raises(TypeError): + with pytest.raises(ValueError): compile_multi_field_embedding((8,), (8, 200), (8, 200), dtype.int16, dtype.float32, dtype.int16) @non_graph_engine def test_check_multifield_embedding_false_value_shape(): - with pytest.raises(TypeError): + with pytest.raises(ValueError): compile_multi_field_embedding((8, 200), (8,), (8, 200), dtype.int16, dtype.float32, dtype.int16) diff --git a/tests/ut/python/nn/test_ssim.py b/tests/ut/python/nn/test_ssim.py index 8b7e441014..8e6866cb59 100644 --- a/tests/ut/python/nn/test_ssim.py +++ b/tests/ut/python/nn/test_ssim.py @@ -84,7 +84,7 @@ def test_ssim_different_shape(): img1 = Tensor(np.random.random(shape_1)) img2 = Tensor(np.random.random(shape_2)) net = SSIMNet() - with pytest.raises(ValueError): + with pytest.raises(TypeError): _executor.compile(net, img1, img2) @@ -108,9 +108,9 @@ def test_ssim_invalid_5d_input(): invalid_img2 = Tensor(np.random.random(invalid_shape)) net = SSIMNet() - with pytest.raises(ValueError): + with pytest.raises(TypeError): _executor.compile(net, invalid_img1, img2) - with pytest.raises(ValueError): + with pytest.raises(TypeError): _executor.compile(net, img1, invalid_img2) - with pytest.raises(ValueError): + with pytest.raises(TypeError): _executor.compile(net, invalid_img1, invalid_img2) diff --git a/tests/ut/python/ops/test_control_ops.py b/tests/ut/python/ops/test_control_ops.py index a7d96a7772..03e30b719a 100644 --- a/tests/ut/python/ops/test_control_ops.py +++ b/tests/ut/python/ops/test_control_ops.py @@ -613,18 +613,6 @@ def test_switch_layer_single_layer(): net2(x, i) -def test_control_depend_check(): - with pytest.raises(TypeError) as e: - P.ControlDepend(0.0) - print(e) - with pytest.raises(ValueError) as e: - P.ControlDepend(2) - print(e) - with pytest.raises(TypeError) as e: - P.ControlDepend((2,)) - print(e) - - def test_if_nested_compile(): class Net(nn.Cell): def __init__(self, auto_prefix=True): diff --git a/tests/ut/python/ops/test_math_ops.py b/tests/ut/python/ops/test_math_ops.py index ac4b1ac434..a2d7a1c233 100755 --- a/tests/ut/python/ops/test_math_ops.py +++ b/tests/ut/python/ops/test_math_ops.py @@ -24,6 +24,7 @@ from mindspore import Tensor from mindspore.common import dtype as mstype from mindspore.ops import composite as C from mindspore.ops import operations as P +from mindspore.ops import functional as F from mindspore.ops import prim_attr_register, PrimitiveWithInfer from ..ut_filter import non_graph_engine from ....mindspore_test_framework.mindspore_test import mindspore_test @@ -272,12 +273,14 @@ class NpuFloatNet(nn.Cell): self.sub = 
P.Sub() self.neg = P.Neg() - @C.add_flags(has_effect=True) def construct(self, x): init = self.alloc_status() - self.clear_status(init) + clear_status = self.clear_status(init) + x = F.depend(x, clear_status) # let x depend on clear_status res = self.sub(x, self.neg(x)) - self.get_status(init) + init = F.depend(init, res) # let get_status depend on res + get_status = self.get_status(init) + init = F.depend(init, get_status) # let reduce_sum depend on get_status flag_sum = self.reduce_sum(init, (0,)) base = self.cast(self.fill(self.dtype(res), self.shape_op(res), 0.0), self.dtype(flag_sum)) cond = self.less(base, flag_sum) 
diff --git a/tests/ut/python/parallel/test_allreduce_fusion.py b/tests/ut/python/parallel/test_allreduce_fusion.py index 659560ac97..b7147945a9 100644 --- a/tests/ut/python/parallel/test_allreduce_fusion.py +++ b/tests/ut/python/parallel/test_allreduce_fusion.py @@ -13,6 +13,7 @@ # limitations under the License. import numpy as np +import pytest import mindspore as ms import mindspore.nn as nn @@ -25,7 +26,6 @@ from mindspore.parallel._auto_parallel_context import auto_parallel_context from mindspore.train import Model from mindspore.context import ParallelMode from tests.dataset_mock import MindData -import pytest class Dataset(MindData): 
diff --git a/tests/ut/python/parallel/test_loss_scale.py b/tests/ut/python/parallel/test_loss_scale.py index 9649679474..f8497de2cd 100644 --- a/tests/ut/python/parallel/test_loss_scale.py +++ b/tests/ut/python/parallel/test_loss_scale.py @@ -69,9 +69,8 @@ class TrainOneStepWithLossScaleCell(nn.Cell): self.cast = P.Cast() self.alloc_status = P.NPUAllocFloatStatus() self.get_status = P.NPUGetFloatStatus() - self.clear_before_grad = P.NPUClearFloatStatus() + self.clear_status = P.NPUClearFloatStatus() self.reduce_sum = P.ReduceSum(keep_dims=False) - self.depend_parameter_use = P.ControlDepend(depend_mode=1) self.base = Tensor(1, mstype.float32) self.less_equal = P.LessEqual() self.hyper_map = C.HyperMap() @@ -81,7 +80,6 @@ class TrainOneStepWithLossScaleCell(nn.Cell): self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32), name="loss_scale") - @C.add_flags(has_effect=True) def construct(self, x, sens=None): """Defines the computation performed.""" weights = self.weights @@ -92,12 +90,16 @@ class TrainOneStepWithLossScaleCell(nn.Cell): scaling_sens = sens # alloc status and clear should be right before gradoperation init = self.alloc_status() - self.clear_before_grad(init) + init = F.depend(init, loss) + clear_status = self.clear_status(init) + scaling_sens = F.depend(scaling_sens, clear_status) grads = self.grad(self.network, weights)(x, self.cast(scaling_sens, mstype.float32)) # apply grad reducer on grads grads = self.grad_reducer(grads) grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) - self.get_status(init) + init = F.depend(init, grads) + get_status = self.get_status(init) + init = F.depend(init, get_status) flag_sum = self.reduce_sum(init, (0,)) cond = self.less_equal(self.base, flag_sum) overflow = cond
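The TrainOneStepWithLossScaleCell rewrite above is the recurring shape of this migration: F.depend(value, expr) returns value while forcing expr to execute before any consumer of value, so the alloc_status -> clear_status -> grad -> get_status -> reduce_sum ordering that has_effect=True used to impose implicitly is now written as explicit data edges (a standalone sketch of the idiom, using only operators already present in these tests, follows at the end of this patch).
diff --git a/tests/ut/python/parallel/test_manual_embedding_lookup.py b/tests/ut/python/parallel/test_manual_embedding_lookup.py index 542348946f..8e210f95f7 100644 --- a/tests/ut/python/parallel/test_manual_embedding_lookup.py +++ b/tests/ut/python/parallel/test_manual_embedding_lookup.py @@ -22,7 +22,11 @@ from mindspore.nn import Cell, TrainOneStepCell, LazyAdam from mindspore.ops import operations as P from mindspore.common.initializer import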
initializer -context.set_context(enable_sparse=True) +@pytest.fixture(scope="module", autouse=True) +def setup_teardown(): + context.set_context(enable_sparse=True) + yield + context.set_context(enable_sparse=False) class Net(Cell): diff --git a/tests/ut/python/parallel/test_multi_field_embedding.py b/tests/ut/python/parallel/test_multi_field_embedding.py index a30a1a7c54..f858caf262 100644 --- a/tests/ut/python/parallel/test_multi_field_embedding.py +++ b/tests/ut/python/parallel/test_multi_field_embedding.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ +import pytest import numpy as np import mindspore as ms @@ -26,6 +27,14 @@ from tests.ut.python.ops.test_math_ops import VirtualLoss grad_all = C.GradOperation(get_all=True) +@pytest.fixture(name="test_context") +def _test_context(): + context.set_context(enable_sparse=True) + yield + context.set_context(enable_sparse=False) + context.reset_auto_parallel_context() + + class GradWrap(nn.Cell): def __init__(self, network): super(GradWrap, self).__init__() @@ -52,7 +61,7 @@ class Net(nn.Cell): super().__init__() self.embedding = nn.MultiFieldEmbeddingLookup(vocab_size=32, embedding_size=64, target=target, field_size=field_size, slice_mode=slice_mode, operator=operator) - self.reshape = P.Reshape().shard(((8, 1, 1),)) + self.reshape = P.Reshape() self.batch_size = shape[0] def construct(self, x, y, z): @@ -62,7 +71,6 @@ class Net(nn.Cell): def compile_net(net, shape): - context.set_context(enable_sparse=True) x = Tensor(np.ones(shape), dtype=ms.int32) y = Tensor(np.ones(shape), dtype=ms.float32) z = Tensor(np.ones(shape), dtype=ms.int32) @@ -74,63 +82,63 @@ def compile_net(net, shape): context.reset_auto_parallel_context() -def test_embeddinglookup_batch_parallel_sum(): +def test_embeddinglookup_batch_parallel_sum(test_context): context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") shape = [64, 64] net = NetWithLoss(Net(shape, field_size=10, target='DEVICE')) compile_net(net, shape) -def test_embeddinglookup_row_parallel_sum(): +def test_embeddinglookup_row_parallel_sum(test_context): context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") shape = [64, 64] net = NetWithLoss(Net(shape, field_size=9, slice_mode=nn.EmbeddingLookup.TABLE_ROW_SLICE, target='DEVICE')) compile_net(net, shape) -def test_embeddinglookup_column_parallel_sum(): +def test_embeddinglookup_column_parallel_sum(test_context): context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") shape = [64, 64] net = NetWithLoss(Net(shape, field_size=10, slice_mode=nn.EmbeddingLookup.TABLE_COLUMN_SLICE, target='DEVICE')) compile_net(net, shape) -def test_embeddinglookup_batch_parallel_mean(): +def test_embeddinglookup_batch_parallel_mean(test_context): context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") shape = [64, 64] net = NetWithLoss(Net(shape, field_size=1, target='DEVICE', operator='MEAN')) compile_net(net, shape) -def test_embeddinglookup_column_parallel_mean(): +def test_embeddinglookup_column_parallel_mean(test_context): context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") shape = [64, 64] net = NetWithLoss(Net(shape, target='DEVICE', slice_mode=nn.EmbeddingLookup.TABLE_COLUMN_SLICE, operator='MEAN')) 
compile_net(net, shape) -def test_embeddinglookup_row_parallel_mean(): +def test_embeddinglookup_row_parallel_mean(test_context): context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") shape = [64, 64] net = NetWithLoss(Net(shape, target='DEVICE', slice_mode=nn.EmbeddingLookup.TABLE_ROW_SLICE, operator='MEAN')) compile_net(net, shape) -def test_embeddinglookup_batch_parallel_max(): +def test_embeddinglookup_batch_parallel_max(test_context): context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") shape = [64, 64] net = NetWithLoss(Net(shape, target='DEVICE', operator='MAX')) compile_net(net, shape) -def test_embeddinglookup_column_parallel_max(): +def test_embeddinglookup_column_parallel_max(test_context): context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") shape = [64, 64] net = NetWithLoss(Net(shape, target='DEVICE', slice_mode=nn.EmbeddingLookup.TABLE_COLUMN_SLICE, operator='MAX')) compile_net(net, shape) -def test_embeddinglookup_row_parallel_max(): +def test_embeddinglookup_row_parallel_max(test_context): context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") shape = [64, 64] net = NetWithLoss(Net(shape, target='DEVICE', slice_mode=nn.EmbeddingLookup.TABLE_ROW_SLICE, operator='MAX')) diff --git a/tests/ut/python/parallel/test_sparse_feature_bprop.py b/tests/ut/python/parallel/test_sparse_feature_bprop.py index 1ba968d62b..549a64d71a 100644 --- a/tests/ut/python/parallel/test_sparse_feature_bprop.py +++ b/tests/ut/python/parallel/test_sparse_feature_bprop.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================ """ test sparse feature bprop """ +import pytest import numpy as np import mindspore as ms @@ -29,6 +30,14 @@ from mindspore.nn import TrainOneStepCell, Adam grad_all = C.GradOperation(get_all=True) +@pytest.fixture(name="test_context") +def _test_context(): + context.set_context(enable_sparse=True) + yield + context.set_context(enable_sparse=False) + context.reset_auto_parallel_context() + + class GradWrap(nn.Cell): def __init__(self, network): super(GradWrap, self).__init__() @@ -37,9 +46,8 @@ class GradWrap(nn.Cell): def construct(self, x): return grad_all(self.network)(x) -def test_bprop_with_sparse_feature_allreduce(): +def test_bprop_with_sparse_feature_allreduce(test_context): context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="hybrid_parallel") - context.set_context(enable_sparse=True) class Net(nn.Cell): def __init__(self, axis=0, shape=None): @@ -64,9 +72,8 @@ def test_bprop_with_sparse_feature_allreduce(): _executor.compile(net, x) -def test_bprop_with_sparse_feature_mirror(): +def test_bprop_with_sparse_feature_mirror(test_context): context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") - context.set_context(enable_sparse=True) class Net(nn.Cell): def __init__(self, shape=None): @@ -95,9 +102,8 @@ def test_bprop_with_sparse_feature_mirror(): compile_net(net) -def test_bprop_with_sparse_feature_dataparallel(): +def test_bprop_with_sparse_feature_dataparallel(test_context): context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="data_parallel") - context.set_context(enable_sparse=True) class Net(nn.Cell): def __init__(self, axis=0, shape=None): diff --git a/tests/ut/python/pipeline/infer/test_auto_monad.py 
b/tests/ut/python/pipeline/infer/test_auto_monad.py new file mode 100644 index 0000000000..5be520525a --- /dev/null +++ b/tests/ut/python/pipeline/infer/test_auto_monad.py @@ -0,0 +1,348 @@ +import numpy as np +import pytest + +import mindspore as ms +import mindspore.ops.composite as C +from mindspore import context +import mindspore.nn as nn +from mindspore.ops import operations as P +from mindspore.ops import functional as F +from mindspore import Tensor +from mindspore.common.parameter import Parameter, ParameterTuple + +grad_all_list = C.GradOperation(get_all=True, get_by_list=True) +grad_by_list = C.GradOperation(get_by_list=True) + +context.set_context(mode=context.GRAPH_MODE, save_graphs=False) + + +def test_load_grad(): + class LoadNet(nn.Cell): + def __init__(self): + super().__init__() + self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z') + + def construct(self, x, y): + x = x * y * self.z + return x + + x = Tensor(np.array([2.0], np.float32)) + y = Tensor(np.array([3.0], np.float32)) + load_net = LoadNet() + grad_net = grad_all_list( + load_net, ParameterTuple(load_net.trainable_params())) + print(grad_net(x, y)) + + +def test_assign_only_grad(): + class AssignOnlyNet(nn.Cell): + def __init__(self): + super().__init__() + self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z') + + def construct(self, x, y): + self.z = x + x = x * y + return x + + class GradNet(nn.Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.net = net + self.parameter_tuple = ParameterTuple(self.trainable_params()) + + def construct(self, x, y): + return grad_all_list(self.net, self.parameter_tuple)(x, y) + + assign_net = AssignOnlyNet() + net = GradNet(assign_net) + x = Tensor(np.array([2.0], np.float32)) + y = Tensor(np.array([3.0], np.float32)) + print(net(x, y)) + + +def test_load_assign_grad(): + class AssignNet(nn.Cell): + def __init__(self): + super().__init__() + self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z') + self.assign = P.Assign() + + def construct(self, x, y): + x = x * self.z + self.assign(self.z, x) + out = y * self.z + return out + + class GradNet(nn.Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.net = net + self.parameter_tuple = ParameterTuple(net.trainable_params()) + + def construct(self, x, y): + return grad_all_list(self.net, self.parameter_tuple)(x, y) + + assign_net = AssignNet() + net = GradNet(assign_net) + x = Tensor(np.array([2.0], np.float32)) + y = Tensor(np.array([3.0], np.float32)) + print(net(x, y)) + + +def test_insert_gradient_of(): + class InsertGradientNet(nn.Cell): + def __init__(self): + super(InsertGradientNet, self).__init__() + self.gather = P.GatherV2() + self.damping = Tensor(np.array([0.03, 0.03], np.float32)) + self.cov_step = Parameter(0, name="cov_step", requires_grad=False) + self.freq = Tensor(278, ms.int32) + self.getG = P.InsertGradientOf(self.save_gradient) + + def save_gradient(self, dout): + self.cov_step = self.cov_step + self.freq + return dout + + def construct(self, x): + self.gather(self.damping, self.cov_step, 0) + out = P.ReLU()(x) + out = self.getG(out) + out = self.getG(out) + return out + + net = InsertGradientNet() + input_data = np.array([[1.2, 2.1], [2.2, 3.2]]).astype(np.float32) + grad_net = grad_all_list(net, ParameterTuple(net.trainable_params())) + print(grad_net(Tensor(input_data))) + + +def test_user_defined_bprop(): + class UserDefinedNet(nn.Cell): + def __init__(self): + super().__init__() + self.print = P.Print() + + def 
construct(self, x, y): + out = x * y + return out + + def bprop(self, x, y, out, dout): + self.print(out) + out = x * y + self.print(out) + self.print(dout) + return y, x + + class GradNet(nn.Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.net = net + self.parameter_tuple = ParameterTuple(net.trainable_params()) + + def construct(self, x, y): + return grad_all_list(self.net, self.parameter_tuple)(x, y) + + user_defined_net = UserDefinedNet() + net = GradNet(user_defined_net) + x = Tensor(np.array([2.0], np.float32)) + y = Tensor(np.array([3.0], np.float32)) + print(net(x, y)) + + +# the user-defined bprop doesn't have the same number of parameters as the primal +def test_user_defined_bad_bprop(): + class UserDefinedNet(nn.Cell): + def __init__(self): + super().__init__() + self.print = P.Print() + + def construct(self, x, y): + out = x * y + return out + + def bprop(self, x, out, dout): + self.print(out) + out = x + self.print(out) + self.print(dout) + return x, x + + class GradNet(nn.Cell): + def __init__(self, net): + super(GradNet, self).__init__() + self.net = net + self.parameter_tuple = ParameterTuple(net.trainable_params()) + + def construct(self, x, y): + return grad_all_list(self.net, self.parameter_tuple)(x, y) + + user_defined_net = UserDefinedNet() + net = GradNet(user_defined_net) + x = Tensor(np.array([2.0], np.float32)) + y = Tensor(np.array([3.0], np.float32)) + with pytest.raises(TypeError): + net(x, y) + + +# should compile successfully, and Print is present in the final function graph. +def test_unused_var(): + class UnusedVar(nn.Cell): + def __init__(self): + super().__init__() + self.print = P.Print() + + def construct(self, x, y): + shape1 = self.get_shape(x) + out = x + for _ in range(shape1): + out = out + y + return out + + def get_shape(self, x): + self.print(x) + _, c, _, _ = F.shape(x) + return c + + net = UnusedVar() + x = Tensor(np.ones(shape=[3, 2, 1, 2]), ms.float32) + y = Tensor(np.ones(shape=[3, 2, 1, 2]), ms.float32) + print(net(x, y)) + + +# should compile successfully, and Print is present in the final function graph. +def test_hof_unused_var(): + class UnusedVar(nn.Cell): + def __init__(self): + super().__init__() + self.print = P.Print() + + def construct(self, x, y): + shape1 = self.hof_get_shape(self.get_shape, x) + out = x + for _ in range(shape1): + out = out + y + return out + + def hof_get_shape(self, hof, x): + return hof(x) + + def get_shape(self, x): + self.print(x) + _, c, _, _ = F.shape(x) + return c + + net = UnusedVar() + x = Tensor(np.ones(shape=[3, 2, 1, 2]), ms.float32) + y = Tensor(np.ones(shape=[3, 2, 1, 2]), ms.float32) + print(net(x, y)) + + +# should compile successfully, and Print is present in the final function graph. +def test_partial_hof_unused_var(): + class UnusedVar(nn.Cell): + def __init__(self): + super().__init__() + self.print = P.Print() + + def construct(self, x, y): + shape1 = self.hof_get_shape(x)() + out = x + for _ in range(shape1): + out = out + y + return out + + def hof_get_shape(self, x): + return F.partial(self.get_shape, x) + + def get_shape(self, x): + self.print(x) + _, c, _, _ = F.shape(x) + return c + + net = UnusedVar() + x = Tensor(np.ones(shape=[3, 2, 1, 2]), ms.float32) + y = Tensor(np.ones(shape=[3, 2, 1, 2]), ms.float32) + print(net(x, y))
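+ + +# Note (illustrative, not part of the original test file): the three tests above exercise +# the same guarantee from different call shapes (a direct call, a higher-order call, and a +# Partial): a Print reached only through an otherwise-unused value must still appear in the +# final graph, because auto-monad threads the IO side effect instead of letting dead-code +# elimination prune it.
+ + +# should compile successfully, without an endless loop.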
+def test_while_if(): + class WhileIfNet(nn.Cell): + def __init__(self): + super().__init__() + self.zero = Tensor(np.zeros([1]).astype(np.float32)) + self.param = Parameter(Tensor(np.zeros([1]).astype(np.float32))) + + def construct(self, idx, end, x): + out = self.zero + while idx < end: + if x < end: + out = out + self.param * 2 + else: + out = out + self.param + idx = idx + 1 + return out + + idx = Tensor(np.array(0), dtype=ms.int32) + end = Tensor(np.array(5), dtype=ms.int32) + x = Tensor(np.zeros([1]).astype(np.float32)) + m = WhileIfNet() + m(idx, end, x) + + +# should compile successfully, without a zeros_like_tensor argument mismatch; the generated graph files +# should not contain env_getitem or env_setitem. +# The InsertGradientOf primitive sets the BackPropAutoMonad flag on the func_graph bprop_construct, +# so every graph it uses is checked for side effects; hyper_map_zeros_like therefore takes +# U as a parameter, while the call site zeros_like(fv) has no U argument. +def test_grad_fv_and_insert_gradient_of(): + class FvAndInsertGradientNet(nn.Cell): + def __init__(self): + super(FvAndInsertGradientNet, self).__init__() + self.gather = P.GatherV2() + self.damping = Tensor(np.array([0.03, 0.03], np.float32)) + self.cov_step = Parameter(0, name="cov_step", requires_grad=False) + self.freq = Tensor(278, ms.int32) + self.getG = P.InsertGradientOf(self.save_gradient) + + self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z') + + def save_gradient(self, dout): + self.cov_step = self.cov_step + self.freq + return dout + + def construct(self, *inputs): + # fv self.z from construct_wrapper + x, = inputs + self.z = x + + # insert_gradient_of + self.gather(self.damping, self.cov_step, 0) + out = self.getG(x) + return out + + net = FvAndInsertGradientNet() + input_data = Tensor(np.array([1.0], np.float32)) + # if grad_all_list were used, the generated graph would contain env_setitem, + # as the gradient for the inputs is constant zero and would depend on the result of grad. + grad_net = grad_by_list(net, ParameterTuple(net.trainable_params())) + print(grad_net(input_data))
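+ + +# For reference (an illustrative note, not from the original test file): stateful primitives +# in the auto-monad IR thread an explicit monad value, roughly +# u2 = UpdateState(u1, Assign(param, x, u1)) # the write is ordered on monad u1 +# out = Load(param, u2) # the read happens after that write +# which is the contract that the Load/UpdateState stubs at the end of this patch emulate.
+ + +# should compile successfully, as a CNode with the Partial primitive does not bind an additional U monad.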
+def test_partial_parameter(): + z = Parameter(Tensor(np.array([True], np.bool_)), name='z') + + class PartialNet(nn.Cell): + def __init__(self, input_z): + super().__init__() + self.input = input_z + + def construct(self): + # getattr of all will be converted to Partial + out = self.input.all(axis=()) + return out + + net = PartialNet(z) + print(net()) 
diff --git a/tests/ut/python/pipeline/parse/test_parse.py b/tests/ut/python/pipeline/parse/test_parse.py index 947d094fc9..4b7ebb59fb 100644 --- a/tests/ut/python/pipeline/parse/test_parse.py +++ b/tests/ut/python/pipeline/parse/test_parse.py @@ -37,6 +37,12 @@ from ...ut_filter import non_graph_engine # pylint: disable=W0613,W0612 # W0613: unused-argument +@pytest.fixture(name='enable_check_bprop') +def fixture_enable_check_bprop(): + context.set_context(check_bprop=True) + yield + context.set_context(check_bprop=False) + grad_all = C.GradOperation(get_all=True) @@ -146,8 +152,7 @@ def test_net_with_ndarray(): net(ms.Tensor(input_data)) -def test_bprop_with_wrong_output_num(): - context.set_context(check_bprop=True) +def test_bprop_with_wrong_output_num(enable_check_bprop): class BpropWithWrongOutputNum(PrimitiveWithInfer): @prim_attr_register def __init__(self): @@ -182,8 +187,7 @@ def test_bprop_with_wrong_output_num(): grad_all(BpropWithWrongOutputNumCell())(Tensor(np.array(1).astype(np.int32)), Tensor(np.array(2).astype(np.int32))) -def test_bprop_with_wrong_output_type(): - context.set_context(check_bprop=True) +def test_bprop_with_wrong_output_type(enable_check_bprop): class BpropWithWrongOutputType(PrimitiveWithInfer): @prim_attr_register def __init__(self): @@ -218,8 +222,7 @@ def test_bprop_with_wrong_output_type(): grad_all(BpropWithWrongOutputTypeCell())(Tensor(np.ones([64, 10]).astype(np.int32))) -def test_bprop_with_wrong_output_shape(): - context.set_context(check_bprop=True) +def test_bprop_with_wrong_output_shape(enable_check_bprop): class BpropWithWrongOutputShape(PrimitiveWithInfer): @prim_attr_register def __init__(self): @@ -307,7 +310,7 @@ class Assign(nn.Cell): return self.cov_step -def test_assign(): +def test_assign(enable_check_bprop): context.set_context(mode=context.GRAPH_MODE) net = Assign() input_data = ms.Tensor(np.array(1).astype(np.int32)) 
diff --git a/tests/ut/python/pynative_mode/test_cont_cases.py b/tests/ut/python/pynative_mode/test_cont_cases.py index a3d295e9ad..5d1186e6e4 100644 --- a/tests/ut/python/pynative_mode/test_cont_cases.py +++ b/tests/ut/python/pynative_mode/test_cont_cases.py @@ -13,6 +13,7 @@ # limitations under the License.
# ============================================================================ """ test control ops """ +import pytest import numpy as np from mindspore import dtype as ms from mindspore import Tensor @@ -30,8 +31,11 @@ grad_by_list = C.GradOperation(get_by_list=True) grad_all = C.GradOperation(get_all=True) -def setup_module(): - context.set_context(mode=context.PYNATIVE_MODE, enable_sparse=False) +@pytest.fixture(scope="module", autouse=True) +def setup_teardown(): + context.set_context(mode=context.PYNATIVE_MODE, precompile_only=True) + yield + context.set_context(mode=context.GRAPH_MODE, precompile_only=False) def test_while_with_param_forward_with_const_branch(): @@ -683,7 +687,7 @@ def test_if_by_if_forward(): def test_if_by_if_forward_control_tuple_switch(): - """tuple_get from swtich op will generate new switch inside to eliminate tuple_get""" + """tuple_get from switch op will generate new switch inside to eliminate tuple_get""" class Branch3Net(nn.Cell): def __init__(self): super().__init__() diff --git a/tests/ut/python/pynative_mode/test_graph_param_cases.py b/tests/ut/python/pynative_mode/test_graph_param_cases.py index 96a1ab25fc..55b2ba566a 100644 --- a/tests/ut/python/pynative_mode/test_graph_param_cases.py +++ b/tests/ut/python/pynative_mode/test_graph_param_cases.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ +import pytest import numpy as np from mindspore import RowTensor from mindspore import context, nn, Tensor, ParameterTuple @@ -20,8 +21,11 @@ from mindspore.common import ms_function from mindspore.ops import composite as C -def setup_module(): - context.set_context(mode=context.PYNATIVE_MODE, enable_sparse=False) +@pytest.fixture(scope="module", autouse=True) +def setup_teardown(): + context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") + yield + context.set_context(mode=context.GRAPH_MODE) class _Grad(nn.Cell): diff --git a/tests/ut/python/pynative_mode/test_multi_grad.py b/tests/ut/python/pynative_mode/test_multi_grad.py index a59bc9f3be..a6bcd4adf1 100644 --- a/tests/ut/python/pynative_mode/test_multi_grad.py +++ b/tests/ut/python/pynative_mode/test_multi_grad.py @@ -12,14 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================ +import pytest import numpy as np from mindspore import context, nn, Tensor, Parameter, ParameterTuple from mindspore.common import dtype as mstype from mindspore.ops import composite as C -def setup_module(): - context.set_context(mode=context.PYNATIVE_MODE, enable_sparse=False) +@pytest.fixture(scope="module", autouse=True) +def setup_teardown(): + context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") + yield + context.set_context(mode=context.GRAPH_MODE) class _Grad(nn.Cell): diff --git a/tests/ut/python/pynative_mode/test_sparse_pynative.py b/tests/ut/python/pynative_mode/test_sparse_pynative.py index 3568491b23..488bc12965 100644 --- a/tests/ut/python/pynative_mode/test_sparse_pynative.py +++ b/tests/ut/python/pynative_mode/test_sparse_pynative.py @@ -18,12 +18,17 @@ @Date : 2020-08-04 @Desc : test mindspore sparse pynative """ +import pytest import mindspore as ms import mindspore.nn as nn from mindspore import context, Tensor, RowTensor, SparseTensor from mindspore.ops import composite as C -context.set_context(mode=context.PYNATIVE_MODE, enable_sparse=True) +@pytest.fixture(scope="module", autouse=True) +def setup_teardown(): + context.set_context(mode=context.PYNATIVE_MODE, enable_sparse=True) + yield + context.set_context(mode=context.GRAPH_MODE, enable_sparse=False) grad_all = C.GradOperation(get_all=True) diff --git a/tests/vm_impl/array_ops_vm_impl.py b/tests/vm_impl/array_ops_vm_impl.py index 21628060cc..8fb5f3c53f 100644 --- a/tests/vm_impl/array_ops_vm_impl.py +++ b/tests/vm_impl/array_ops_vm_impl.py @@ -25,7 +25,7 @@ from .vm_interface import vm @vm_impl_getters.register(P.Assign) def vm_impl_assign(self): """Generate vm_impl function for Assign""" - def vm_impl(x, value): + def vm_impl(x, value, u=None): x.assign_value(value) return x return vm_impl @@ -323,3 +323,21 @@ def vm_impl_depend(self): return value return vm_impl + + +@vm_impl_getters.register(P.UpdateState) +def vm_impl_updatestate(self): + """Generate vm_impl function for UpdateState""" + def vm_impl(monad, expr): + return monad + + return vm_impl + + +@vm_impl_getters.register(P.Load) +def vm_impl_load(self): + """Generate vm_impl function for Load""" + def vm_impl(value, u=None): + return value + + return vm_impl
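Taken together, the test rewrites in this patch all follow one idiom. A standalone sketch of that idiom, for reference only (it reuses the Ascend-specific NPU float-status operators and F.depend calls already present in the tests above, and is not itself part of the patch):

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.ops import functional as F

class FloatStatusOrderingCell(nn.Cell):
    """Order alloc/clear/compute/get/reduce with explicit F.depend edges."""
    def __init__(self):
        super().__init__()
        self.alloc_status = P.NPUAllocFloatStatus()
        self.clear_status = P.NPUClearFloatStatus()
        self.get_status = P.NPUGetFloatStatus()
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.sub = P.Sub()

    def construct(self, x, y):
        init = self.alloc_status()
        clear_status = self.clear_status(init)
        x = F.depend(x, clear_status)        # compute only after the status is cleared
        res = self.sub(x, y)
        init = F.depend(init, res)           # read the status only after the computation
        get_status = self.get_status(init)
        init = F.depend(init, get_status)    # reduce only after the status has been read
        flag_sum = self.reduce_sum(init, (0,))
        return F.depend(res, flag_sum)       # the result carries the whole ordering chain

F.depend(value, expr) returns value unchanged and only records that expr must execute first, which is why every removed ControlDepend and has_effect flag above becomes an assignment whose result is actually consumed downstream.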