From ba65fb9f3c444714c12543590e4cfd51ca4c73bb Mon Sep 17 00:00:00 2001
From: liangzelang
Date: Sat, 20 Mar 2021 11:49:56 +0800
Subject: [PATCH] Support non-tail recursive graphs

---
 .../aicpu/aicpu_kernel_build.cc               |   9 +-
 .../aicpu/aicpu_kernel_metadata.cc            |  64 ++---
 .../aicpu/aicpu_kernel_metadata.h             |   4 +-
 .../ascend/format_type/merge_cast_to_op.cc    |   5 +-
 .../backend/session/anf_runtime_algorithm.cc  |  44 ++++
 .../backend/session/anf_runtime_algorithm.h   |   2 +
 .../backend/session/ascend_auto_monad.cc      | 223 +++++++++++++++++-
 mindspore/ccsrc/utils/utils.h                 |   5 +
 mindspore/core/base/core_ops.h                |   6 +
 tests/st/control/test_cont_grad.py            |  94 +++++---
 10 files changed, 380 insertions(+), 76 deletions(-)

diff --git a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_build.cc b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_build.cc
index 1b0cb28043..e34cf29824 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_build.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_build.cc
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -92,7 +92,9 @@ bool SetIOSize(const std::shared_ptr<AnfNode> &anf_node, const std::shared_ptr<AicpuOpKernelMod> &kernel_mod_ptr) {
   kernel_mod_ptr->SetInputSizeList(input_size_list);
-
+  if (output_num == 1 && HasAbstractMonad(anf_node)) {
+    output_num = 0;
+  }
   for (size_t i = 0; i < output_num; i++) {
     std::vector<size_t> shape_i = AnfAlgo::GetOutputDeviceShape(anf_node, i);
     TypePtr type_ptr = TypeIdToType(AnfAlgo::GetOutputDeviceDataType(anf_node, i));
@@ -229,6 +231,9 @@ void SetNodeOutputs(const std::shared_ptr<AnfNode> &anf_node, mindspore::NodeDef *proto) {
   MS_EXCEPTION_IF_NULL(proto);
   MS_EXCEPTION_IF_NULL(anf_node);
   size_t output_num = AnfAlgo::GetOutputTensorNum(anf_node);
+  if (output_num == 1 && HasAbstractMonad(anf_node)) {
+    output_num = 0;
+  }
   if (output_num == 0) {
     MS_LOG(INFO) << "Node [" << AnfAlgo::GetCNodeName(anf_node) << "] does not have output. ";
     return;
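Note: both hunks above apply the same convention — a node whose only output is a monad (an ordering token that carries no data) is treated as having zero real outputs, so no device buffer is sized or serialized for it. A minimal self-contained sketch of that filtering rule, with toy types standing in for AnfNode and HasAbstractMonad (illustrative only, not the framework API):

#include <cstdio>
#include <vector>

// Toy stand-in for an ANF node: it reports an output count and whether its
// single output is a monad (ordering token that occupies no device memory).
struct ToyNode {
  size_t output_num;
  bool has_abstract_monad;  // plays the role of HasAbstractMonad(anf_node)
};

// Mirrors the patch: a single monad output means "no real outputs to size".
std::vector<size_t> CollectOutputSizes(const ToyNode &node, size_t bytes_each) {
  size_t output_num = node.output_num;
  if (output_num == 1 && node.has_abstract_monad) {
    output_num = 0;
  }
  return std::vector<size_t>(output_num, bytes_each);
}

int main() {
  const ToyNode stack_push{1, true};  // e.g. StackPush: only a monad comes out
  const ToyNode ordinary{2, false};
  printf("%zu %zu\n", CollectOutputSizes(stack_push, 4).size(),
         CollectOutputSizes(ordinary, 4).size());  // prints: 0 2
  return 0;
}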
diff --git a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_metadata.cc b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_metadata.cc
index 70e6498a7b..4dd08d665d 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_metadata.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_metadata.cc
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -38,32 +38,10 @@ void AicpuMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
-  std::vector<std::string> inputs_format{};
-  std::vector<TypeId> inputs_type{};
-  if (op_name == kPrint || op_name == kPack || op_name == kMeshgrid) {
-    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
-    for (size_t input_index = 0; input_index < input_num; ++input_index) {
-      inputs_format.emplace_back(kOpFormat_DEFAULT);
-      inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index));
-    }
-  }
-  std::vector<std::string> outputs_format;
-  std::vector<TypeId> outputs_type;
-  size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
-  for (size_t output_index = 0; output_index < output_num; ++output_index) {
-    outputs_format.emplace_back(kOpFormat_DEFAULT);
-    outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index));
-  }
-  auto builder = KernelBuildInfo::KernelBuildInfoBuilder();
-  builder.SetInputsFormat(inputs_format);
-  builder.SetInputsDeviceType(inputs_type);
-  builder.SetOutputsFormat(outputs_format);
-  builder.SetOutputsDeviceType(outputs_type);
-  builder.SetProcessor(AICPU);
-  builder.SetKernelType(AICPU_KERNEL);
-  builder.SetFusionType(OPAQUE);
-  kernel_info_list->push_back(builder.Build());
+  if (op_name == kPrint || op_name == kGetNext || op_name == kPack || op_name == kMeshgrid ||
+      op_name == kStackInitOpName || op_name == kStackDestroyOpName || op_name == kStackPushOpName ||
+      op_name == kStackPopOpName) {
+    AicpuMetadataInfoForSpecialNodes(kernel_node, kernel_info_list);
     return;
   }
   if (!ParseMetadata(kernel_node, op_info_ptr, AICPU, kernel_info_list)) {
@@ -71,5 +49,37 @@ void AicpuMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
   }
 }
+
+void AicpuMetadataInfoForSpecialNodes(const CNodePtr &kernel_node,
+                                      std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list) {
+  std::vector<std::string> inputs_format{};
+  std::vector<TypeId> inputs_type{};
+  auto op_name = AnfAlgo::GetCNodeName(kernel_node);
+  if (op_name == kPrint || op_name == kPack || op_name == kMeshgrid || op_name == kStackInitOpName ||
+      op_name == kStackDestroyOpName || op_name == kStackPushOpName || op_name == kStackPopOpName) {
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    for (size_t input_index = 0; input_index < input_num; ++input_index) {
+      inputs_format.emplace_back(kOpFormat_DEFAULT);
+      inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index));
+    }
+  }
+  std::vector<std::string> outputs_format;
+  std::vector<TypeId> outputs_type;
+  size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
+  for (size_t output_index = 0; output_index < output_num; ++output_index) {
+    outputs_format.emplace_back(kOpFormat_DEFAULT);
+    outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index));
+  }
+  auto builder = KernelBuildInfo::KernelBuildInfoBuilder();
+  builder.SetInputsFormat(inputs_format);
+  builder.SetInputsDeviceType(inputs_type);
+  builder.SetOutputsFormat(outputs_format);
+  builder.SetOutputsDeviceType(outputs_type);
+  builder.SetProcessor(AICPU);
+  builder.SetKernelType(AICPU_KERNEL);
+  builder.SetFusionType(OPAQUE);
+  kernel_info_list->push_back(builder.Build());
+  return;
+}
 }  // namespace kernel
 }  // namespace mindspore
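Note: the ops routed to AicpuMetadataInfoForSpecialNodes (Print, GetNext, Pack, Meshgrid, and the new stack ops) take a variable number of inputs, so their formats cannot come from a static op-info table; the selector instead assigns the default format and the inferred type to every input and output it finds on the node. A minimal sketch of that fallback with toy types (not the KernelBuildInfo API):

#include <cstdio>
#include <string>
#include <vector>

// Toy build info: one format string per input/output, like the builder fills in.
struct ToyBuildInfo {
  std::vector<std::string> inputs_format;
  std::vector<std::string> outputs_format;
};

// For variadic ops, every input and output simply gets the default format.
ToyBuildInfo SelectDefaultInfo(size_t input_num, size_t output_num) {
  ToyBuildInfo info;
  info.inputs_format.assign(input_num, "DefaultFormat");
  info.outputs_format.assign(output_num, "DefaultFormat");
  return info;
}

int main() {
  const auto info = SelectDefaultInfo(3, 1);  // e.g. a Pack op with 3 inputs
  printf("%zu inputs, %zu outputs\n", info.inputs_format.size(), info.outputs_format.size());
  return 0;
}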
diff --git a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_metadata.h b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_metadata.h
index d82fa6b02b..1d9ead2f57 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_metadata.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/aicpu/aicpu_kernel_metadata.h
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -25,6 +25,8 @@ namespace mindspore {
 namespace kernel {
 void AicpuMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list);
+void AicpuMetadataInfoForSpecialNodes(const CNodePtr &kernel_node,
+                                      std::vector<std::shared_ptr<kernel::KernelBuildInfo>> *kernel_info_list);
 }  // namespace kernel
 }  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_AICPU_AICPU_KERNEL_META_DATA_H_
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/merge_cast_to_op.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/merge_cast_to_op.cc
index e9250248e4..5a0d6a6d04 100644
--- a/mindspore/ccsrc/backend/optimizer/ascend/format_type/merge_cast_to_op.cc
+++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/merge_cast_to_op.cc
@@ -154,7 +154,7 @@ AnfNodePtr MergeCastToNextOp(const FuncGraphPtr &graph, const CNodePtr &node, const KernelQueryPtr kernel_query) {
     return nullptr;
   }
   auto next_op_name = AnfAlgo::GetCNodeName(next_cnode);
-  if (next_op_name == prim::kPrimSend->name()) {
+  if (next_op_name == prim::kPrimSend->name() || next_op_name == kStackPushOpName) {
     return nullptr;
   }
   std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
@@ -229,7 +229,8 @@ AnfNodePtr MergeCastToPriorOp(const FuncGraphPtr &graph, const CNodePtr &cur_node, const KernelQueryPtr kernel_query) {
   }
   std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
-  if (AnfAlgo::GetCNodeName(prior_op) == prim::kPrimReceive->name()) {
+  if (AnfAlgo::GetCNodeName(prior_op) == prim::kPrimReceive->name() ||
+      AnfAlgo::GetCNodeName(prior_op) == kStackPopOpName) {
     return nullptr;
   }
   kernel_query->Query(prior_op, &kernel_info_list);
diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
index 07e1409a02..288f7c9de0 100644
--- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
+++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
@@ -106,6 +106,43 @@ std::vector<size_t> TransShapeToSizet(const abstract::ShapePtr &shape) {
 enum ShapeType { kMaxShape, kMinShape };
 }  // namespace
 
+AnfNodePtr AnfRuntimeAlgorithm::MakeMonadValueNode(const KernelGraphPtr &kg) {
+  return kg->NewValueNode(kUMonad->ToAbstract(), kUMonad);
+}
+
+// Convert:
+//   a = former(xxx)
+//   b = latter(x, xxx)
+// To:
+//   a = former(xxx)
+//   d1 = Depend(x, a)
+//   b = latter(d1, xxx)
+//   ...
+//   out = Depend(out, latter)
+void AnfRuntimeAlgorithm::KeepOrder(const KernelGraphPtr &kg, const AnfNodePtr &former, const AnfNodePtr &latter) {
+  if (latter->isa<CNode>()) {
+    auto latter_cnode = latter->cast<CNodePtr>();
+    constexpr size_t inputsize = 2;
+    constexpr size_t kFirstDataInputIndex = 1;
+    if (latter_cnode->inputs().size() < inputsize) {
+      return;
+    }
+    auto latter_input = latter_cnode->input(kFirstDataInputIndex);
+    auto depend1 = kg->NewCNode({NewValueNode(prim::kPrimDepend), latter_input, former});
+    depend1->set_abstract(latter_input->abstract());
+    latter_cnode->set_input(kFirstDataInputIndex, depend1);
+
+    auto return_node = kg->get_return();
+    MS_EXCEPTION_IF_NULL(return_node);
+    auto depend2 = kg->NewCNode(
+      {NewValueNode(prim::kPrimDepend), return_node->cast<CNodePtr>()->input(kFirstDataInputIndex), latter});
+    depend2->set_abstract(return_node->cast<CNodePtr>()->input(kFirstDataInputIndex)->abstract());
+    kg->set_output(depend2);
+    MS_LOG(DEBUG) << "former: " << former->DebugString() << ", latter: " << latter->DebugString()
+                  << ", depend1: " << depend1->DebugString() << ", depend2: " << depend2->DebugString();
+  }
+}
+
 AnfNodePtr AnfRuntimeAlgorithm::GetTupleGetItemRealInput(const CNodePtr &tuple_get_item) {
   MS_EXCEPTION_IF_NULL(tuple_get_item);
   if (tuple_get_item->size() != kTupleGetItemInputSize) {
@@ -1529,6 +1566,13 @@ bool AnfRuntimeAlgorithm::IsIndependentNode(const CNodePtr &node) {
     return false;
   }
 
+  // AICPU stack ops are not independent nodes.
+  if (AnfAlgo::GetCNodeName(node) == kStackInitOpName || AnfAlgo::GetCNodeName(node) == kStackDestroyOpName ||
+      AnfAlgo::GetCNodeName(node) == kStackPopOpName || AnfAlgo::GetCNodeName(node) == kStackPushOpName) {
+    MS_LOG(INFO) << "AICPU stack ops should not be independent nodes";
+    return false;
+  }
+
   size_t input_nums = AnfAlgo::GetInputTensorNum(node);
   if (input_nums == 0) {
     return true;
diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h
index f08620994b..bb0a7cb254 100644
--- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h
+++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h
@@ -43,6 +43,8 @@ using DeviceAddress = device::DeviceAddress;
 using DeviceAddressPtr = device::DeviceAddressPtr;
 class AnfRuntimeAlgorithm {
  public:
+  static AnfNodePtr MakeMonadValueNode(const KernelGraphPtr &kg);
+  static void KeepOrder(const KernelGraphPtr &kg, const AnfNodePtr &former, const AnfNodePtr &latter);
   // get real input node of tuple_get_item
   static AnfNodePtr GetTupleGetItemRealInput(const CNodePtr &tuple_get_item);
   static size_t GetTupleGetItemOutIndex(const CNodePtr &tuple_get_item);
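Note: KeepOrder turns a desired execution order into a dataflow edge. Depend(x, a) evaluates to x but lists a as an input, so any topological schedule must run a (former) before the node that consumes x (latter); the second Depend likewise pins latter before the graph output. A self-contained sketch of the idea on a toy DAG (invented node names, not the MindSpore IR):

#include <cstdio>
#include <map>
#include <set>
#include <string>
#include <vector>

using Graph = std::map<std::string, std::set<std::string>>;  // node -> inputs

// Emit nodes in dependency order (inputs first) via DFS postorder.
void Emit(const Graph &g, const std::string &node, std::set<std::string> *seen,
          std::vector<std::string> *order) {
  if (!seen->insert(node).second) {
    return;
  }
  auto it = g.find(node);
  if (it != g.end()) {
    for (const auto &input : it->second) {
      Emit(g, input, seen, order);
    }
  }
  order->push_back(node);
}

int main() {
  // b = latter(x): no data edge connects former to latter, so nothing
  // forces former to run before latter.
  Graph g{{"former", {}}, {"latter", {"x"}}, {"x", {}}};
  // KeepOrder: d1 = Depend(x, former); latter now consumes d1 instead of x,
  // exactly like latter_cnode->set_input(kFirstDataInputIndex, depend1).
  g["d1"] = {"x", "former"};
  g["latter"] = {"d1"};
  std::set<std::string> seen;
  std::vector<std::string> order;
  Emit(g, "latter", &seen, &order);
  for (const auto &n : order) {
    printf("%s ", n.c_str());  // prints: former x d1 latter
  }
  printf("\n");
  return 0;
}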
diff --git a/mindspore/ccsrc/backend/session/ascend_auto_monad.cc b/mindspore/ccsrc/backend/session/ascend_auto_monad.cc
index a06624cdc7..5d4ae3f58c 100644
--- a/mindspore/ccsrc/backend/session/ascend_auto_monad.cc
+++ b/mindspore/ccsrc/backend/session/ascend_auto_monad.cc
@@ -145,6 +145,9 @@ struct CallSite {
   // Call/Switch/SwitchLayer
   CNodePtr cnode;
 
+  // The CNode after conversion to LabelGoto/LabelSwitch/LabelSet.
+  CNodePtr conversion_cnode;
+
   // The last monad before call.
   AnfNodePtr last_monad = nullptr;
 
@@ -286,6 +289,12 @@ class AscendAutoMonadContext : public BaseContext {
 
   const KernelGraphPtr &TopGraph() const { return top_graph_; }
 
+  // Whether a stack has already been created.
+  const bool HasInitedStack() const { return inited_stack_; }
+
+  // Set the flag that indicates whether a stack has been created.
+  void SetInitedStack(bool flag) { inited_stack_ = flag; }
+
   // Map kernel_graph to its call info.
   OrderedMap<KernelGraphPtr, CallInfo> call_info_map;
 
@@ -298,6 +307,9 @@ class AscendAutoMonadContext : public BaseContext {
 
   // Current label id.
   uint32_t label_id_ = 0;
+
+  // Whether the stack for multi-call and non-tail recursion has been created.
+  bool inited_stack_ = false;
 };
 
 //
@@ -605,16 +617,22 @@ class AscendAutoMonadConverter {
  private:
   AscendAutoMonadConverter(const KernelGraphPtr &kg, AscendAutoMonadContext *context, CallInfo *call_info)
-      : kernel_graph_(kg), context_(*context), call_info_(*call_info) {}
+      : kernel_graph_(kg),
+        context_(*context),
+        call_info_(*call_info),
+        name_index_(0),
+        need_stackops_(call_info->recursive) {}
   ~AscendAutoMonadConverter() = default;
 
   void Run() {
+    // Create a stack if needed.
+    InitStack();
     // Setup entry label if found.
     SetupEntryLabel();
 
     // Handle call sites.
     for (auto &call_site : call_info_.call_sites) {
-      HandleCallSite(call_site);
+      HandleCallSite(&call_site);
     }
     // Handle return points.
     HandleReturnPoints();
@@ -622,20 +640,148 @@ class AscendAutoMonadConverter {
     if (monad_) {
       MakeMonadDepend();
     }
+    // Handle recursive calls.
+    kernel_graph_->SetExecOrderByDefault();
+    for (auto &call_site : call_info_.call_sites) {
+      if (need_stackops_ && call_site.recursive) {
+        MS_LOG(INFO) << "graph:" << kernel_graph_->ToString() << ", loop call_site:" << call_site.cnode->DebugString();
+        InsertStackOps(call_site);
+      }
+    }
   }
 
+  // Create a stack for StackOps if needed.
+  void InitStack() {
+    if (!context_.HasInitedStack() && need_stackops_) {
+      auto top_graph = context_.TopGraph();
+      auto exec_order = top_graph->execution_order();
+      auto stack_init = StackInit(top_graph);
+      AnfAlgo::KeepOrder(top_graph, stack_init, *exec_order.begin());
+      auto stack_destroy = StackDestroy(top_graph);
+      AnfAlgo::KeepOrder(top_graph, *exec_order.rbegin(), stack_destroy);
+      top_graph->SetExecOrderByDefault();
+      context_.SetInitedStack(true);
+    }
+  }
+
+  // Insert StackOps for the call_site in the recursive graph.
+  void InsertStackOps(const CallSite &call_site) {
+    auto call_point = call_site.conversion_cnode;
+    auto exec_order = kernel_graph_->execution_order();
+    std::vector<CNodePtr> before_nodes;
+    std::vector<CNodePtr> stack_pushs;
+    bool find_call_point = false;
+    for (auto &node : exec_order) {
+      auto node_name = AnfAlgo::GetCNodeName(node);
+      if (node == call_point) {
+        find_call_point = true;
+        continue;
+      }
+      if (!find_call_point) {
+        if (node_name == kLabelGotoOpName || node_name == kLabelSwitchOpName || node_name == kLabelSetOpName ||
+            node_name == prim::kPrimAssign->name()) {
+          MS_LOG(DEBUG) << "Ignore goto/switch/set/assign ops";
+        } else {
+          before_nodes.push_back(node);
+          MS_LOG(DEBUG) << "push back node:" << node->DebugString();
+        }
+        continue;
+      }
+      if (node->size() == 0 || node_name == kLabelGotoOpName || node_name == kLabelSetOpName ||
+          node_name == prim::kPrimAssign->name()) {
+        continue;
+      }
+      FindInputNode(before_nodes, node, &stack_pushs);
+    }
+    InsertStackPush(kernel_graph_, call_point, stack_pushs);
+  }
+
+  // Find nodes which need StackOps, and insert StackOps for them.
+  void FindInputNode(const std::vector<CNodePtr> &before_nodes, const CNodePtr &node,
+                     std::vector<CNodePtr> *stack_pushs) {
+    uint32_t start_index = 1;
+    if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimAssign)) {
+      start_index = 2;
+    }
+    for (uint32_t i = start_index; i < node->inputs().size(); i++) {
+      auto node_input = node->input(i);
+      // No need to save monad inputs.
+      if (HasAbstractMonad(node_input)) {
+        continue;
+      }
+      MS_LOG(DEBUG) << "check node input[" << i << "]: " << node_input->DebugString();
+      if (node_input->isa<Parameter>()) {
+        MS_LOG(DEBUG) << "node_input:" << node_input->DebugString() << " is a param";
+        CNodePtr stack_pop = InsertStackPop(kernel_graph_, node_input, stack_pushs);
+        node->set_input(i, stack_pop);
+        KeepOrderForStackPop(kernel_graph_, stack_pop, node);
+        continue;
+      }
+      auto iter = std::find_if(before_nodes.begin(), before_nodes.end(),
+                               [node_input](auto before_node) { return before_node == node_input; });
+      if (iter != before_nodes.end()) {
+        CNodePtr stack_pop = InsertStackPop(kernel_graph_, *iter, stack_pushs);
+        node->set_input(i, stack_pop);
+        KeepOrderForStackPop(kernel_graph_, stack_pop, node);
+      }
+    }
+  }
+
+  // Create StackOps for node_input.
+  CNodePtr InsertStackPop(const KernelGraphPtr &kg, const AnfNodePtr &node_input, std::vector<CNodePtr> *stack_pushs) {
+    auto stack_push = StackPush(node_input);
+    stack_pushs->emplace_back(stack_push);
+    auto stack_pop = StackPop();
+    stack_pop->set_abstract(node_input->abstract());
+    return stack_pop;
+  }
+
+  // Arrange the StackPush nodes in last-pushed-first-popped order, while
+  // ensuring that the final StackPush node is placed right before the jump_node.
+  void InsertStackPush(const KernelGraphPtr &kg, const CNodePtr &jump_node, const std::vector<CNodePtr> &stack_pushs) {
+    MS_LOG(DEBUG) << "There are " << stack_pushs.size() << " stack_push ops";
+    if (stack_pushs.size() < 1) {
+      return;
+    }
+    for (uint32_t i = 1; i < stack_pushs.size(); i++) {
+      AnfAlgo::KeepOrder(kg, stack_pushs[i], stack_pushs[i - 1]);
+    }
+    auto nodes = kg->execution_order();
+    auto node_iter = std::find(nodes.begin(), nodes.end(), jump_node);
+    AnfAlgo::KeepOrder(kg, stack_pushs[0], jump_node);
+    if (node_iter != nodes.begin()) {
+      AnfAlgo::KeepOrder(kg, *(node_iter - 1), *stack_pushs.rbegin());
+    }
+  }
 
-  void HandleCallSite(const CallSite &call_site) {
+  // Ensure that the StackPop is placed right before the jump_node.
+  void KeepOrderForStackPop(const KernelGraphPtr &kg, const CNodePtr &pop, const CNodePtr &jump_node) {
+    auto nodes = kg->execution_order();
+    auto node_iter = std::find(nodes.cbegin(), nodes.cend(), jump_node);
+    if (node_iter == nodes.cend()) {
+      MS_LOG(EXCEPTION) << "Cannot find node: " << jump_node->DebugString();
+    }
+    // Insert between jump_node-1 and jump_node.
+    if (node_iter != nodes.begin()) {
+      CNodePtr node = *(node_iter - 1);
+      AnfAlgo::KeepOrder(kg, node, pop);
+    }
+    AnfAlgo::KeepOrder(kg, pop, jump_node);
+  }
+
+  void HandleCallSite(CallSite *call_site) {
     // Update last_monad_.
-    last_monad_ = call_site.last_monad;
+    last_monad_ = call_site->last_monad;
 
     // The call/switch/switch_layer cnode.
-    auto &cnode = call_site.cnode;
+    auto &cnode = call_site->cnode;
 
     // Get branches of the call_site.
     // for call, there is one branch;
     // for switch, the first one is true branch;
     // for switch_layer, the first one is 0 branch.
-    auto &branches = call_site.callees;
+    auto &branches = call_site->callees;
 
     // Link arguments and find labels for branches.
     std::vector<KernelGraphPtr> graphes;
@@ -664,13 +810,14 @@ class AscendAutoMonadConverter {
 
     // Create LabelGoto or LabelSwitch node.
     auto label_goto_switch = MakeLabelGotoSwitch(cnode, graphes, labels);
+    call_site->conversion_cnode = label_goto_switch;
 
     // Setup return label and output if required.
-    if (call_site.return_label != kNoLabel) {
-      auto label_node = LabelSet(call_site.return_label);
-      AnfNodePtr output = call_site.out_param;
+    if (call_site->return_label != kNoLabel) {
+      auto label_node = LabelSet(call_site->return_label);
+      AnfNodePtr output = call_site->out_param;
       MS_EXCEPTION_IF_NULL(output);
-      const bool is_single_call = call_site.label_indexes.empty();
+      const bool is_single_call = call_site->label_indexes.empty();
       if (is_single_call) {
         // For single call, let output depend on the label node,
         // this ensures the return label is set before output is used.
@@ -688,7 +835,7 @@ class AscendAutoMonadConverter {
     }
 
     // If no return label required, it should be a tail call.
-    if (!call_site.tail) {
+    if (!call_site->tail) {
       MS_LOG(EXCEPTION) << "Return label not set for non-tail call " << cnode->DebugString();
     }
     // For tail calls, replace origin call node with label_goto/label_switch.
@@ -697,8 +844,8 @@ class AscendAutoMonadConverter {
   }
 
   // Assign label indexes to label parameters for a call site.
-  void AssignLabelIndexes(const CallSite &call_site) {
-    for (auto &[label_param, label_index] : call_site.label_indexes) {
+  void AssignLabelIndexes(const CallSite *call_site) {
+    for (auto &[label_param, label_index] : call_site->label_indexes) {
       auto index_value = GetIndexValueNode(label_index);
       auto assign = Assign(label_param, index_value, false, false, false);
       monad_ = UpdateState(GetMonad(), assign);
@@ -1020,6 +1167,50 @@ class AscendAutoMonadConverter {
     AnfAlgo::SetNodeAttr(kAttrChildGraph, MakeValue(graphs), node);
   }
 
+  // Make a StackInit node.
+  CNodePtr StackInit(const KernelGraphPtr &kg) {
+    auto monad = AnfAlgo::MakeMonadValueNode(kg);
+    auto stack_init = NewPrimitive(prim::kPrimStackInit);
+    auto cnode = kg->NewCNode({stack_init, monad});
+    AnfAlgo::SetNodeAttr(kAttrIndex, MakeValue(0), cnode);
+    cnode->set_abstract(monad->abstract());
+    return cnode;
+  }
+
+  // Make a StackDestroy node.
+  CNodePtr StackDestroy(const KernelGraphPtr &kg) {
+    auto monad = AnfAlgo::MakeMonadValueNode(kg);
+    auto stack_destroy = NewPrimitive(prim::kPrimStackDestroy);
+    auto cnode = kg->NewCNode({stack_destroy, monad});
+    AnfAlgo::SetNodeAttr(kAttrIndex, MakeValue(0), cnode);
+    cnode->set_abstract(monad->abstract());
+    return cnode;
+  }
+
+  // Make a StackPush node.
+  CNodePtr StackPush(const AnfNodePtr &input) {
+    auto monad = AnfAlgo::MakeMonadValueNode(kernel_graph_);
+    auto stack_push = NewPrimitive(prim::kPrimStackPush);
+    auto cnode = kernel_graph_->NewCNode({stack_push, input, monad});
+    AnfAlgo::SetNodeAttr(kAttrIndex, MakeValue(0), cnode);
+    auto op_name = std::to_string(kernel_graph_->graph_id()) + "_stack_push_" + std::to_string(name_index_++);
+    AnfAlgo::SetNodeAttr(kAttrStackOpName, MakeValue(op_name), cnode);
+    cnode->set_abstract(monad->abstract());
+    return cnode;
+  }
+
+  // Make a StackPop node.
+  CNodePtr StackPop() {
+    auto monad = AnfAlgo::MakeMonadValueNode(kernel_graph_);
+    auto stack_pop = NewPrimitive(prim::kPrimStackPop);
+    auto cnode = kernel_graph_->NewCNode({stack_pop, monad});
+    AnfAlgo::SetNodeAttr(kAttrIndex, MakeValue(0), cnode);
+    auto op_name = std::to_string(kernel_graph_->graph_id()) + "_stack_pop_" + std::to_string(name_index_++);
+    AnfAlgo::SetNodeAttr(kAttrStackOpName, MakeValue(op_name), cnode);
+    cnode->set_abstract(monad->abstract());  // The caller must refresh this output abstract.
+    return cnode;
+  }
+
  private:
   const KernelGraphPtr &kernel_graph_;
   AscendAutoMonadContext &context_;
@@ -1038,6 +1229,12 @@ class AscendAutoMonadConverter {
 
   // Index value node cache for reuse.
   std::map<uint32_t, ValueNodePtr> index_nodes_;
+
+  // Index used to generate unique stack op names.
+  uint32_t name_index_;
+
+  // Whether stack ops need to be inserted for this graph.
+  bool need_stackops_;
 };
 
 constexpr size_t kAssignTargetIndex = 1;
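Note: the converter above exists because a label-based kernel graph has no call stack. When a graph calls itself, any value still needed after the call would be overwritten by the recursive activation, so InsertStackOps saves every such live value with StackPush right before the jump and restores it with StackPop right after the return label, in last-pushed-first-popped order. The same discipline written as plain C++ recursion, with the implicit language stack made explicit (illustrative only):

#include <cassert>
#include <cstdio>
#include <stack>

std::stack<int> value_stack;  // plays the role of the stack made by StackInit

// sum(n) = n + sum(n - 1): 'n' is live across the recursive call, so it is
// saved before the call (StackPush) and restored after it (StackPop).
int Sum(int n) {
  if (n == 0) {
    return 0;
  }
  value_stack.push(n);          // StackPush just before the LabelGoto
  const int rest = Sum(n - 1);  // the recursive call site
  n = value_stack.top();        // StackPop right after the return LabelSet
  value_stack.pop();
  return n + rest;
}

int main() {
  assert(value_stack.empty());  // StackInit/StackDestroy bracket the whole run
  printf("%d\n", Sum(5));       // prints: 15
  return 0;
}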
diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h
index 6077ea46e9..75ef38cbd7 100644
--- a/mindspore/ccsrc/utils/utils.h
+++ b/mindspore/ccsrc/utils/utils.h
@@ -116,6 +116,10 @@ constexpr auto kApplyProximalAdagradOpName = "ApplyProximalAdagrad";
 constexpr auto kApplyProximalGradientDescentOpName = "ApplyProximalGradientDescent";
 constexpr auto kApplyRMSPropOpName = "ApplyRMSProp";
 constexpr auto kTransDataOpName = "TransData";
+constexpr auto kStackInitOpName = "StackInit";
+constexpr auto kStackPushOpName = "StackPush";
+constexpr auto kStackPopOpName = "StackPop";
+constexpr auto kStackDestroyOpName = "StackDestroy";
 constexpr auto kBNTrainingUpdateGradOpName = "BNTrainingUpdateGrad";
 constexpr auto kBNTrainingReduceGradOpName = "BNTrainingReduceGrad";
 constexpr auto kSquareSumV1OpName = "SquareSumV1";
@@ -380,6 +384,7 @@ constexpr auto kAttrRankSize = "rank_size";
 constexpr auto kAttrPadDimSize = "pad_dim_size";
 constexpr auto kAttrPaddings = "paddings";
 constexpr auto kAttrNumSegments = "num_segments";
+constexpr auto kAttrStackOpName = "stack_op_name";
 constexpr auto kAttrBegin = "begin";
 constexpr auto kAttrSize = "size";
 constexpr auto kAttrIsDynamicShape = "is_dynamic_shape";
diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h
index ec46ffc9d7..903e1e0091 100644
--- a/mindspore/core/base/core_ops.h
+++ b/mindspore/core/base/core_ops.h
@@ -105,6 +105,12 @@ inline const PrimitivePtr kPrimLabelGoto = std::make_shared<Primitive>("LabelGoto");
 inline const PrimitivePtr kPrimLabelSwitch = std::make_shared<Primitive>("LabelSwitch");
 inline const PrimitivePtr kPrimLabelSet = std::make_shared<Primitive>("LabelSet");
 
+// Stack ops
+inline const PrimitivePtr kPrimStackInit = std::make_shared<Primitive>("StackInit");
+inline const PrimitivePtr kPrimStackDestroy = std::make_shared<Primitive>("StackDestroy");
+inline const PrimitivePtr kPrimStackPush = std::make_shared<Primitive>("StackPush");
+inline const PrimitivePtr kPrimStackPop = std::make_shared<Primitive>("StackPop");
+
 // Arrays
 inline const PrimitivePtr kPrimBroadcastTo = std::make_shared<Primitive>("BroadcastTo");
 inline const PrimitivePtr kPrimScalarToArray = std::make_shared<Primitive>("scalar_to_array");
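Note: only recursive call sites (call_info->recursive) generate stack traffic. After a tail call nothing in the caller is live, so a bare LabelGoto suffices — which is why tail-recursive graphs already worked before this patch; a non-tail call leaves a pending computation behind, and that is the case the new ops handle. A short contrast in plain C++ (illustrative):

#include <cstdio>

// Tail call: the recursive call is the last action, nothing is live after
// it, so no values need saving (a LabelGoto alone would do).
int GcdTail(int a, int b) { return b == 0 ? a : GcdTail(b, a % b); }

// Non-tail call: 'n' is still needed for the pending "n + ..." after the
// call returns, so it must survive the recursion; this is what the new
// StackPush/StackPop pairs provide inside a kernel graph.
int SumNonTail(int n) { return n == 0 ? 0 : n + SumNonTail(n - 1); }

int main() {
  printf("%d %d\n", GcdTail(12, 18), SumNonTail(5));  // prints: 6 15
  return 0;
}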
diff --git a/tests/st/control/test_cont_grad.py b/tests/st/control/test_cont_grad.py
index 8fa917461a..00bf7534c7 100644
--- a/tests/st/control/test_cont_grad.py
+++ b/tests/st/control/test_cont_grad.py
@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -33,7 +33,7 @@ grad_by_list = C.GradOperation(get_by_list=True)
 grad_all = C.GradOperation(get_all=True)
 
 
-def test_while_forward():
+def test_while_grad():
     class MyWhileNet(nn.Cell):
         def __init__(self):
             super().__init__()
@@ -46,31 +46,71 @@ def test_while_forward():
             x[idx, :, 0:2] = max_num
             idx = idx + 1
         return x
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+
+        def construct(self, *inputs):
+            return grad_all(self.net)(*inputs)
 
     # graph mode
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
-    net = MyWhileNet()
+    while_net = MyWhileNet()
+    net = GradNet(while_net)
     idx = Tensor(np.array(0), dtype=ms.int32)
     end = Tensor(np.array(2), dtype=ms.int32)
-    x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
+    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
     graph_output = net(idx, end, x)
-    #pynative mode
+    # pynative mode
    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
     pynative_output = net(idx, end, x)
-    assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)
+    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)
+    assert np.allclose(graph_output[1].asnumpy(), pynative_output[1].asnumpy(), 0.0001, 0.0001)
+    assert np.allclose(graph_output[2].asnumpy(), pynative_output[2].asnumpy(), 0.0001, 0.0001)
+
+def test_while_with_const_param_grad():
+    class MyWhileNet(nn.Cell):
+        def __init__(self):
+            super().__init__()
+            self.mul = P.Mul()
+            self.add = P.Add()
+
+        def construct(self, x, y):
+            while x < y:
+                z = self.mul(x, x)
+                x = self.add(z, 1)
+            return x
+
+    class GradNet(nn.Cell):
+        def __init__(self, net):
+            super(GradNet, self).__init__()
+            self.net = net
+
+        def construct(self, *inputs):
+            return grad_all(self.net)(*inputs)
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+    while_net = MyWhileNet()
+    net = GradNet(while_net)
+    idx = Tensor([1.1], dtype=ms.float32)
+    end = Tensor([8.0], dtype=ms.float32)
+    graph_output = net(idx, end)
+    expect_one = np.array([1.14433983e+02], dtype=np.float32)
+    expect_two = np.array([0], dtype=np.float32)
+    assert np.allclose(graph_output[0].asnumpy(), expect_one, 0.0001, 0.0001)
+    assert np.allclose(graph_output[1].asnumpy(), expect_two, 0.0001, 0.0001)
+
-def test_while_grad():
+def test_while_with_variable_grad():
     class MyWhileNet(nn.Cell):
         def __init__(self):
             super().__init__()
-            self.max = P.ReduceMax()
+            self.mul = P.Mul()
+            self.add = P.Add()
 
-        def construct(self, idx, end, x):
-            while idx < end:
-                part = x[idx, :, :]
-                max_num = self.max(part)
-                x[idx, :, 0:2] = max_num
-                idx = idx + 1
+        def construct(self, x, y):
+            while x < y:
+                z = self.mul(x, x)
+                x = self.add(z, y)
             return x
 
     class GradNet(nn.Cell):
@@ -80,20 +120,16 @@ def test_while_grad():
         def construct(self, *inputs):
             return grad_all(self.net)(*inputs)
 
-    # graph mode
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
     while_net = MyWhileNet()
     net = GradNet(while_net)
-    idx = Tensor(np.array(0), dtype=ms.int32)
-    end = Tensor(np.array(2), dtype=ms.int32)
-    x = Tensor(np.random.randn(2, 2, 2).astype(np.float32), dtype=ms.float32)
-    graph_output = net(idx, end, x)
-    # pynative mode
-    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
-    pynative_output = net(idx, end, x)
-    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)
-    assert np.allclose(graph_output[1].asnumpy(), pynative_output[1].asnumpy(), 0.0001, 0.0001)
-    assert np.allclose(graph_output[2].asnumpy(), pynative_output[2].asnumpy(), 0.0001, 0.0001)
+    idx = Tensor([1.1], dtype=ms.float32)
+    end = Tensor([8.0], dtype=ms.float32)
+    graph_output = net(idx, end)
+    expect_one = np.array([2.20000005e+00], dtype=np.float32)
+    expect_two = np.array([1.00000000e+00], dtype=np.float32)
+    assert np.allclose(graph_output[0].asnumpy(), expect_one, 0.0001, 0.0001)
+    assert np.allclose(graph_output[1].asnumpy(), expect_two, 0.0001, 0.0001)
 
 def test_while_with_param_forward():
     class MyWhileNet(nn.Cell):
@@ -153,7 +189,6 @@ def test_while_endless_case():
     pynative_output = net(idx, end, x)
     assert np.allclose(graph_output.asnumpy(), pynative_output.asnumpy(), 0.0001, 0.0001)
 
-
 def test_while_with_param_grad():
     class MyWhileNet(nn.Cell):
         def __init__(self):
@@ -180,7 +215,6 @@ def test_while_with_param_grad():
         def construct(self, a, b, c):
             return grad_by_list(self.net, self.weights)(a, b, c)
 
-    # graph mode
     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
     while_net = MyWhileNet()
     net = GradNet(while_net)
@@ -188,10 +222,8 @@ def test_while_with_param_grad():
     end = Tensor(np.array(2), dtype=ms.int32)
     x = Tensor(np.arange(8).reshape(2, 2, 2).astype(np.float32), dtype=ms.float32)
     graph_output = net(idx, end, x)
-    # pynative mode
-    context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
-    pynative_output = net(idx, end, x)
-    assert np.allclose(graph_output[0].asnumpy(), pynative_output[0].asnumpy(), 0.0001, 0.0001)
+    expect = np.array([[[2, 2], [2, 2]], [[2, 2], [2, 2]]], dtype=np.int32)
+    assert np.allclose(graph_output[0].asnumpy(), expect, 0.0001, 0.0001)
 
 def test_while_with_param_forward_with_const_branch():
     class MyWhileNet(nn.Cell):