From 70abe362f2d9502572d792790584e49d28db2146 Mon Sep 17 00:00:00 2001
From: leilei_snow
Date: Mon, 6 Jul 2020 07:07:32 +0000
Subject: [PATCH] add case process

---
 mindspore/ccsrc/utils/convert_utils.cc | 14 +++++++
 mindspore/ccsrc/utils/convert_utils.h  |  1 +
 mindspore/ccsrc/vm/backend.cc          |  1 +
 mindspore/ccsrc/vm/backend.h           |  1 +
 mindspore/ccsrc/vm/transform.cc        | 51 +++++++++++++++++++++++++-
 mindspore/ccsrc/vm/transform.h         |  1 +
 mindspore/ccsrc/vm/vm.cc               | 29 +++++++++++++++
 mindspore/ccsrc/vm/vm.h                | 11 ++++--
 8 files changed, 103 insertions(+), 6 deletions(-)

diff --git a/mindspore/ccsrc/utils/convert_utils.cc b/mindspore/ccsrc/utils/convert_utils.cc
index 8cb071b769..29f45709c8 100644
--- a/mindspore/ccsrc/utils/convert_utils.cc
+++ b/mindspore/ccsrc/utils/convert_utils.cc
@@ -230,6 +230,20 @@ bool ValueToBool(const ValuePtr &v, bool *value) {
   return true;
 }
 
+bool BaseRefToInt(const ValuePtr &v, int *value) {
+  MS_EXCEPTION_IF_NULL(v);
+  if (v->isa<tensor::Tensor>()) {
+    auto tensor = v->cast<tensor::TensorPtr>();
+    (void)tensor->data_sync();
+    int *tensor_data = static_cast<int *>(tensor->data_c());
+    auto vb = tensor_data[0];
+    *value = vb;
+    return true;
+  }
+  MS_LOG(ERROR) << "Index must be tensor type.";
+  return false;
+}
+
 bool BaseRefToBool(const BaseRef &v, bool *value) {
   if (utils::isa<ValuePtr>(v)) {
     return ValueToBool(utils::cast<ValuePtr>(v), value);
diff --git a/mindspore/ccsrc/utils/convert_utils.h b/mindspore/ccsrc/utils/convert_utils.h
index 40c3e88c5c..a6c9052eae 100644
--- a/mindspore/ccsrc/utils/convert_utils.h
+++ b/mindspore/ccsrc/utils/convert_utils.h
@@ -42,6 +42,7 @@ using TensorPtr = std::shared_ptr<tensor::Tensor>;
 py::object AnyToPyData(const Any &value);
 py::object BaseRefToPyData(const BaseRef &value);
 bool BaseRefToBool(const BaseRef &in, bool *out);
+bool BaseRefToInt(const ValuePtr &v, int *value);
 bool ValueToBool(const ValuePtr &in, bool *out);
 py::object ValuePtrToPyData(const ValuePtr &value);
 
diff --git a/mindspore/ccsrc/vm/backend.cc b/mindspore/ccsrc/vm/backend.cc
index 47bc69bbbb..88a07c7c12 100644
--- a/mindspore/ccsrc/vm/backend.cc
+++ b/mindspore/ccsrc/vm/backend.cc
@@ -32,6 +32,7 @@ namespace mindspore {
 namespace compile {
 bool Backend::GetCond(const BaseRef &c, bool *const value) { return BaseRefToBool(c, value); }
+bool Backend::GetIndex(const BaseRef &c, int *const value) { return BaseRefToInt(utils::cast<ValuePtr>(c), value); }
 
 LinConvertResult MsBackend::GetMultiGraphRun(const FuncGraphPtr &g) {
   // multi_graph merge to one, big graph have paramters in begin and only have one output
diff --git a/mindspore/ccsrc/vm/backend.h b/mindspore/ccsrc/vm/backend.h
index 3a93cf930f..c8d0696fa4 100644
--- a/mindspore/ccsrc/vm/backend.h
+++ b/mindspore/ccsrc/vm/backend.h
@@ -46,6 +46,7 @@ class Backend {
   virtual void SimulateRun(FinalVMPtr, FuncGraphPtr) {}
   virtual SwitchCondStatus SetSimuCond(const BaseRef &, bool) { return kCondOk; }
   virtual bool GetCond(const BaseRef &c, bool *value);
+  virtual bool GetIndex(const BaseRef &c, int *value);
   virtual void SetSwitchGraph() {}
   virtual void SetSwitchActive(const BaseRef &, bool) {}
   virtual void RecallGraphInput(const FuncGraphPtr &, const VectorRef &, const BaseRef &) {}
diff --git a/mindspore/ccsrc/vm/transform.cc b/mindspore/ccsrc/vm/transform.cc
index 80d2fc9df9..e145a55bbd 100644
--- a/mindspore/ccsrc/vm/transform.cc
+++ b/mindspore/ccsrc/vm/transform.cc
@@ -46,8 +46,9 @@ using TypedPrimitiveAbstractClosurePtr = std::shared_ptr<abstract::TypedPrimitiveAbstractClosure>;
 std::vector<PrimitivePtr> nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial, prim::kPrimSwitch,
                                            prim::kPrimMakeTuple, prim::kPrimBpropCut};
 const std::vector<PrimitivePtr> &GetMsNonlinearOps() {
-  static const std::vector<PrimitivePtr> ms_nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial, prim::kPrimSwitch,
-                                                             prim::kPrimBpropCut};
+  static const std::vector<PrimitivePtr> ms_nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial,
+                                                             prim::kPrimSwitch, prim::kPrimMakeTuple,
+                                                             prim::kPrimBpropCut, prim::kPrimSwitchLayer};
   return ms_nonlinear_ops;
 }
 
@@ -187,6 +188,30 @@ std::vector<CNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string &
   std::reverse(result.begin(), result.end());
   return result;
 }
+
+bool IsSubGraph(const AnfNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  if (node->isa<CNode>()) {
+    auto cnode = node->cast<CNodePtr>();
+    auto &inputs = cnode->inputs();
+    if (inputs.empty()) {
+      MS_LOG(EXCEPTION) << "Inputs of apply node is empty";
+    }
+
+    AnfNodePtr fn = inputs[0];
+    MS_EXCEPTION_IF_NULL(fn);
+    if (!IsValueNode<Primitive>(fn)) {
+      return false;
+    }
+    auto node_prim = GetValueNode<PrimitivePtr>(fn);
+    if (node_prim->name() == prim::kPrimPartial->name()) {
+      return true;
+    }
+  } else if (IsValueNode<FuncGraph>(node)) {
+    return true;
+  }
+  return false;
+}
 }  // namespace
 
 CompileGraph::CompileGraph(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list)
@@ -235,6 +260,15 @@ bool CompileGraph::IsCut(const AnfNodePtr &node) {
       MS_EXCEPTION_IF_NULL(ms_context);
       ms_context->set_enable_pynative_hook(true);
     }
+
+    if (backend_->name() == kMsConvert && prim->name() == prim::kPrimMakeTuple->name()) {
+      if (inputs.size() < 2) {
+        return false;
+      }
+      auto ret = IsSubGraph(inputs[1]);
+      return ret;
+    }
+
     return true;
   }
 }
@@ -466,6 +500,8 @@ int CompileGraph::InterpretNode(const FuncGraphPtr &graph, const CNodePtr &node)
   } else if (IsPrimitive(fn, prim::kPrimSwitch)) {
     AddSwitch(node);
     AddSinkSwitch(node);
+  } else if (IsPrimitive(fn, prim::kPrimSwitchLayer)) {
+    AddSwitchLayer(node);
   } else if (IsPrimitive(fn, prim::kPrimMakeTuple)) {
     AddMakeTuple(node);
   } else {
@@ -622,6 +658,17 @@ void CompileGraph::AddSwitch(const CNodePtr &node) {
   AddInst(Instruction::kSwitch, args);
 }
 
+void CompileGraph::AddSwitchLayer(const CNodePtr &node) {
+  auto inputs = node->inputs();
+  if (inputs.size() != 3) {
+    MS_LOG(EXCEPTION) << "Switch layer must have index and branches.";
+  }
+  VectorRef args;
+  args.emplace_back(Ref(inputs[1]));
+  args.emplace_back(Ref(inputs[2]));
+  AddInst(Instruction::kSwitchLayer, args);
+}
+
 void CompileGraph::AddReturn(const CNodePtr &node) {
   VectorRef args;
   if (backend_->simu_flag()) {
diff --git a/mindspore/ccsrc/vm/transform.h b/mindspore/ccsrc/vm/transform.h
index a02478fc1b..55c32ea4e3 100644
--- a/mindspore/ccsrc/vm/transform.h
+++ b/mindspore/ccsrc/vm/transform.h
@@ -90,6 +90,7 @@ class CompileGraph {
   void AddPartial(const CNodePtr &node);
   void AddMakeTuple(const CNodePtr &node);
   void AddSwitch(const CNodePtr &node);
+  void AddSwitchLayer(const CNodePtr &node);
   void AddReturn(const CNodePtr &node);
   void AddPrimitive(const CNodePtr &node, const PrimitivePtr &prim);
   void AddInput(const AnfNodePtr &node);
diff --git a/mindspore/ccsrc/vm/vm.cc b/mindspore/ccsrc/vm/vm.cc
index c73d41df6c..f65b8bef4e 100644
--- a/mindspore/ccsrc/vm/vm.cc
+++ b/mindspore/ccsrc/vm/vm.cc
@@ -480,6 +480,35 @@ void FinalVM::InstSwitch(const VectorRef &args) {
   MS_LOG(DEBUG) << "End";
 }
 
+void FinalVM::InstSwitchLayer(const VectorRef &args) {
+  MS_LOG(DEBUG) << "Start";
+  const size_t args_size = 2;
+  if (args.size() != args_size) {
+    MS_LOG(ERROR) << __FUNCTION__ << " requires " << args_size << " parameters, while the input size is " << args.size()
+                  << ".";
+    return;
+  }
+
+  int idx = utils::cast<int>(args[0]);
+  VectorRef branches = utils::cast<VectorRef>(Ref(utils::cast<int>(args[1])));
+  int size = static_cast<int>(branches.size());
+
+  BaseRef index = Ref(idx);
+  int idx_value = 0;
+  if (!backend_->GetIndex(index, &idx_value)) {
+    MS_LOG(EXCEPTION) << "Not supported type to be casted to int.";
+  }
+  if (idx_value < 0) {
+    // Add support negative index range [-size, -1].
+    idx_value += size;
+  }
+  if (idx_value < 0 || idx_value >= size) {
+    MS_LOG(EXCEPTION) << __FUNCTION__ << " given index " << idx_value << " out of range.";
+  }
+  Push(branches[idx_value]);
+  MS_LOG(DEBUG) << "End";
+}
+
 void FinalVM::InstTuple(const VectorRef &args) {
   MS_LOG(DEBUG) << "Start";
   VectorRef tuple;
diff --git a/mindspore/ccsrc/vm/vm.h b/mindspore/ccsrc/vm/vm.h
index 6a078c9baf..e905ec528b 100644
--- a/mindspore/ccsrc/vm/vm.h
+++ b/mindspore/ccsrc/vm/vm.h
@@ -51,15 +51,17 @@ enum Instruction {
   kPush,
   kPrim,
   kGraph,
-  kPadStack
+  kPadStack,
+  kSwitchLayer
 };
 
 using InstType = std::pair<Instruction, VectorRef>;
 using InstSet = std::vector<InstType>;
 using InstFunctionMap = std::map<Instruction, std::function<void(const VectorRef &)>>;
 
-const std::vector<std::string> inst_str{"call", "tail_call", "return", "partial", "switch", "switch_return", "tuple",
-                                        "input", "external", "push", "primitive", "graph", "pad_stack"};
+const std::vector<std::string> inst_str{"call",          "tail_call", "return",    "partial",  "switch",
+                                        "switch_return", "tuple",     "input",     "external", "push",
+                                        "primitive",     "graph",     "pad_stack", "switch_layer"};
 class StructPartial : public Base {
  public:
   // Initialize StructPartial.
@@ -114,6 +116,7 @@ class FinalVM {
   void InstExternal(const VectorRef &args);
   void InstPushPrim(const VectorRef &args);
   void InstSwitchReturn(const VectorRef &args);
+  void InstSwitchLayer(const VectorRef &args);
   void set_insts(const InstSet &value) { insts_ = value; }
   BaseRef RunHook(const PrimitivePtr &prim, const VectorRef &arg);
 
@@ -157,7 +160,7 @@ class FinalVM {
       {Instruction::kExternal, [this](const VectorRef &args) { InstExternal(args); }},
       {Instruction::kPrim, [this](const VectorRef &args) { InstPushPrim(args); }},
       {Instruction::kSwitchReturn, [this](const VectorRef &args) { InstSwitchReturn(args); }},
-  };
+      {Instruction::kSwitchLayer, [this](const VectorRef &args) { InstSwitchLayer(args); }}};
 
   std::map<PrimitivePtr, py::object> _hook_grad;
 };