diff --git a/mindspore/ccsrc/operator/ops.cc b/mindspore/ccsrc/operator/ops.cc index c85e6a72ce..d5228a9950 100755 --- a/mindspore/ccsrc/operator/ops.cc +++ b/mindspore/ccsrc/operator/ops.cc @@ -78,6 +78,10 @@ const PrimitivePtr kPrimEmbed = std::make_shared("embed"); const PrimitivePtr kPrimRefToEmbed = std::make_shared("RefToEmbed"); const PrimitivePtr kPrimCreateInstance = std::make_shared("create_instance"); +const PrimitivePtr kPrimLabelGoto = std::make_shared("LabelGoto"); +const PrimitivePtr kPrimLabelSwitch = std::make_shared("LabelSwitch"); +const PrimitivePtr kPrimLabelSet = std::make_shared("LabelSet"); + // Structure const PrimitivePtr kPrimStringEqual = std::make_shared("string_equal"); const PrimitivePtr kPrimStringConcat = std::make_shared("string_concat"); diff --git a/mindspore/ccsrc/operator/ops.h b/mindspore/ccsrc/operator/ops.h index d7394be310..d0dd33a1aa 100755 --- a/mindspore/ccsrc/operator/ops.h +++ b/mindspore/ccsrc/operator/ops.h @@ -84,6 +84,10 @@ extern const PrimitivePtr kPrimEmbed; extern const PrimitivePtr kPrimRefToEmbed; extern const PrimitivePtr kPrimCreateInstance; +extern const PrimitivePtr kPrimLabelGoto; +extern const PrimitivePtr kPrimLabelSwitch; +extern const PrimitivePtr kPrimLabelSet; + // Structure extern const PrimitivePtr kPrimStringEqual; extern const PrimitivePtr kPrimStringConcat; diff --git a/mindspore/ccsrc/pipeline/action.cc b/mindspore/ccsrc/pipeline/action.cc index b22e9c9993..5e43293983 100644 --- a/mindspore/ccsrc/pipeline/action.cc +++ b/mindspore/ccsrc/pipeline/action.cc @@ -269,13 +269,41 @@ bool GeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kGePa bool VmOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kVmPasses); } +static bool IsCtrlSink() { + auto ms_ctx = MsContext::GetInstance(); + std::string device_target = ms_ctx->device_target(); + if (device_target != kAscendDevice) { + return false; + } + + if (!ms_ctx->enable_task_sink()) { + return false; + } + + char *enable_ctrl_sink = std::getenv("ENABLE_CTRL_SINK"); + if (enable_ctrl_sink == nullptr) { + return false; + } + std::string enable_ctrl_sink_str(enable_ctrl_sink); + if (enable_ctrl_sink_str == "0") { + return false; + } + + return true; +} + bool TaskEmitAction(const ResourcePtr &res) { if (res->func_graph() == nullptr) { MS_LOG(EXCEPTION) << "TaskEmit args error"; } FuncGraphPtr func_graph = res->func_graph(); - auto bc_ptr = res->results()[kBackend].cast(); + + if (IsCtrlSink()) { + res->results()[kOutput] = bc_ptr->CompileGraph(NOT_NULL(func_graph)); + return true; + } + std::vector cut_list = compile::nonlinear_ops; if (bc_ptr->name() == kMsConvert) { cut_list = compile::GetMsNonlinearOps(); @@ -286,10 +314,31 @@ bool TaskEmitAction(const ResourcePtr &res) { } bool ExecuteAction(const ResourcePtr &res) { - if (res->results().count(kOutput) == 0 || !res->results()[kOutput].is()) { + if (res->results().count(kOutput) == 0) { MS_LOG(EXCEPTION) << "Execute args error"; } + if (IsCtrlSink()) { + if (!res->results()[kOutput].is()) { + MS_LOG(EXCEPTION) << "Execute args error"; + } + + auto graph_id = res->results()[kOutput].cast(); + auto bc_ptr = res->results()[kBackend].cast>(); + compile::VmEvalFuncPtr run = + std::make_shared([&bc_ptr, graph_id](const VectorRef &args) -> BaseRef { + MS_LOG(INFO) << "Execute args size" << args.size(); + auto outs = bc_ptr->RunGraph(graph_id, args); + MS_LOG(DEBUG) << "out size" << outs.size(); + return outs[0]; + }); + res->results()[kOutput] = run; + return true; + } + + if (!res->results()[kOutput].is()) { + MS_LOG(EXCEPTION) << "Execute args error"; + } compile::FinalVMPtr vm = res->results()[kOutput].cast(); if (vm == nullptr) { MS_LOG(INFO) << "Call GE to Run the func_graph instead of VM"; diff --git a/mindspore/ccsrc/session/ascend_session.cc b/mindspore/ccsrc/session/ascend_session.cc index 4fb46b604d..5a6dc792c8 100644 --- a/mindspore/ccsrc/session/ascend_session.cc +++ b/mindspore/ccsrc/session/ascend_session.cc @@ -138,7 +138,7 @@ GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrL return graph_id; } -GraphId AscendSession::CompileGraph(const FuncGraphPtr &func_graph) { +GraphId AscendSession::CompileGraph(NotNull func_graph) { MS_LOG(INFO) << "start"; auto graph = ConstructKernelGraph(func_graph); // split switch diff --git a/mindspore/ccsrc/session/ascend_session.h b/mindspore/ccsrc/session/ascend_session.h index 4823d292a4..45662250d3 100755 --- a/mindspore/ccsrc/session/ascend_session.h +++ b/mindspore/ccsrc/session/ascend_session.h @@ -42,7 +42,7 @@ class AscendSession : public SessionBasic { context_ = std::make_shared(kAscendDevice, device_id); } GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; - GraphId CompileGraph(const FuncGraphPtr &func_graph) override; + GraphId CompileGraph(NotNull func_graph) override; void RunGraph(const GraphId &graph_id, const std::vector &inputs, VectorRef *outputs) override; void BuildGraph(GraphId) override; void BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, diff --git a/mindspore/ccsrc/session/session_basic.h b/mindspore/ccsrc/session/session_basic.h index 0d55b185f4..efd27743f7 100755 --- a/mindspore/ccsrc/session/session_basic.h +++ b/mindspore/ccsrc/session/session_basic.h @@ -28,6 +28,7 @@ #include "ir/meta_tensor.h" #include "utils/any.h" #include "utils/base_ref.h" +#include "utils/contract.h" #include "pynative/pynative_execute.h" #include "device/kernel_info.h" @@ -57,7 +58,7 @@ class SessionBasic { virtual ~SessionBasic() { summary_callback_ = nullptr; } virtual GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) = 0; - virtual GraphId CompileGraph(const FuncGraphPtr &) { return kInvalidGraphId; } + virtual GraphId CompileGraph(NotNull func_graph) { return kInvalidGraphId; } // build graph, used to handle multiple child graphs virtual void BuildGraph(GraphId) {} diff --git a/mindspore/ccsrc/vm/backend.cc b/mindspore/ccsrc/vm/backend.cc index caf4eb3ee3..a9f526418f 100644 --- a/mindspore/ccsrc/vm/backend.cc +++ b/mindspore/ccsrc/vm/backend.cc @@ -327,5 +327,9 @@ MsBackend::MsBackend(const std::string &name, const std::string &target, uint32_ sess_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback); } +GraphId MsBackend::CompileGraph(NotNull fg) { return sess_->CompileGraph(fg); } + +VectorRef MsBackend::RunGraph(GraphId graph_id, const VectorRef &args) { return MsRunGraph(graph_id, args); } + } // namespace compile } // namespace mindspore diff --git a/mindspore/ccsrc/vm/backend.h b/mindspore/ccsrc/vm/backend.h index 769dab473e..94b7a500e2 100644 --- a/mindspore/ccsrc/vm/backend.h +++ b/mindspore/ccsrc/vm/backend.h @@ -22,6 +22,7 @@ #include #include +#include "utils/contract.h" #include "ir/anf.h" #include "vm/segment_runner.h" #include "vm/vm.h" @@ -49,7 +50,7 @@ class Backend { virtual void SetSwitchActive(const BaseRef &, bool) {} virtual void RecallGraphInput(const FuncGraphPtr &, const VectorRef &, const BaseRef &) {} virtual void SetGraphUserInputs(const FuncGraphPtr &, const FuncGraphPtr &, const AnfNodePtrList &) {} - + virtual GraphId CompileGraph(NotNull fg) { return kInvalidGraphId; } void set_curr_switch(const BaseRef &value) { curr_switch_ = value; is_switch_call_ = true; @@ -104,6 +105,8 @@ class MsBackend : public Backend { void Link(GraphId) override; AnfNodePtr ConvertGraphInput(const FuncGraphPtr &, const AnfNodePtr &); LinConvertResult GetMultiGraphRun(const FuncGraphPtr &g) override; + GraphId CompileGraph(NotNull fg) override; + VectorRef RunGraph(GraphId graph_id, const VectorRef &args); private: session::SessionPtr sess_;