From b88e722e1299e3fb5a19bb0a626881fcfef8c8ca Mon Sep 17 00:00:00 2001 From: kswang Date: Tue, 10 Nov 2020 19:51:19 +0800 Subject: [PATCH] add graph partition --- mindspore/ccsrc/backend/session/executor.cc | 8 +- mindspore/ccsrc/backend/session/executor.h | 4 +- .../ccsrc/backend/session/session_basic.cc | 4 +- .../ccsrc/backend/session/session_basic.h | 6 +- mindspore/ccsrc/pipeline/jit/action.cc | 2 +- mindspore/ccsrc/pipeline/jit/pipeline.cc | 3 +- mindspore/ccsrc/vm/backend.cc | 29 +- mindspore/ccsrc/vm/backend.h | 3 +- mindspore/ccsrc/vm/graph_partition.cc | 447 ++++++++++++++++++ mindspore/ccsrc/vm/graph_partition.h | 48 ++ mindspore/ccsrc/vm/segment_runner.cc | 16 +- mindspore/ccsrc/vm/segment_runner.h | 12 +- mindspore/ccsrc/vm/transform.cc | 430 +---------------- mindspore/ccsrc/vm/transform.h | 13 +- mindspore/core/ir/anf.cc | 16 + mindspore/core/ir/anf.h | 7 + mindspore/core/ir/func_graph.cc | 13 + mindspore/core/ir/func_graph.h | 1 + tests/ut/cpp/vm/segment_runner_test.cc | 48 +- 19 files changed, 605 insertions(+), 505 deletions(-) create mode 100644 mindspore/ccsrc/vm/graph_partition.cc create mode 100644 mindspore/ccsrc/vm/graph_partition.h diff --git a/mindspore/ccsrc/backend/session/executor.cc b/mindspore/ccsrc/backend/session/executor.cc index 6752c3a4ec..c732513870 100644 --- a/mindspore/ccsrc/backend/session/executor.cc +++ b/mindspore/ccsrc/backend/session/executor.cc @@ -89,7 +89,8 @@ bool TensorInVector(const VectorRef *outputs) { void CompileNodesTask::Run() { MS_EXCEPTION_IF_NULL(session_); - graph_id_ = session_->CompileGraphImpl(nodes_, output_nodes_); + MS_EXCEPTION_IF_NULL(segment_); + graph_id_ = session_->CompileGraphImpl(segment_->nodes_, output_nodes_); } void CompileGraphTask::Run() { @@ -226,10 +227,11 @@ void Executor::SyncRunTask(const std::shared_ptr &task) { MsException::GetInstance().CheckException(); } -GraphId Executor::CompileGraph(const SessionPtr &session, const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { +GraphId Executor::CompileGraph(const SessionPtr &session, const GraphSegmentPtr &segment, + const AnfNodePtrList &outputs) { auto task = std::make_shared(); task->session_ = session; - task->nodes_ = lst; + task->segment_ = segment; task->output_nodes_ = outputs; SyncRunTask(task); return task->graph_id_; diff --git a/mindspore/ccsrc/backend/session/executor.h b/mindspore/ccsrc/backend/session/executor.h index c026177592..d57193f8c6 100644 --- a/mindspore/ccsrc/backend/session/executor.h +++ b/mindspore/ccsrc/backend/session/executor.h @@ -63,7 +63,7 @@ class CompileNodesTask : public Task { CompileNodesTask() { type_ = kCompileNodes; } ~CompileNodesTask() override = default; void Run() override; - AnfNodePtrList nodes_; + GraphSegmentPtr segment_; AnfNodePtrList output_nodes_; GraphId graph_id_{0}; }; @@ -151,7 +151,7 @@ class Executor { ~Executor(); void WorkerLoop(); void WorkerJoin(); - GraphId CompileGraph(const SessionPtr &session, const AnfNodePtrList &lst, const AnfNodePtrList &outputs); + GraphId CompileGraph(const SessionPtr &session, const GraphSegmentPtr &segment, const AnfNodePtrList &outputs); GraphId CompileGraph(const SessionPtr &session, NotNull func_graph); void BuildGraph(const SessionPtr &session, GraphId graphId); void RunGraph(const SessionPtr &session, const GraphId &graph_id, const std::vector &inputs, diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index 1048f15246..ae70504152 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -1388,9 +1388,9 @@ AnfNodePtr SessionBasic::FindPullNode(const AnfNodePtr &push_node, const std::ve return nullptr; } -GraphId SessionBasic::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) { +GraphId SessionBasic::CompileGraph(const GraphSegmentPtr &segment, const AnfNodePtrList &outputs) { MS_EXCEPTION_IF_NULL(executor_); - return executor_->CompileGraph(shared_from_this(), lst, outputs); + return executor_->CompileGraph(shared_from_this(), segment, outputs); } GraphId SessionBasic::CompileGraph(NotNull func_graph) { diff --git a/mindspore/ccsrc/backend/session/session_basic.h b/mindspore/ccsrc/backend/session/session_basic.h index 3f0c2d80b1..915a24da2f 100644 --- a/mindspore/ccsrc/backend/session/session_basic.h +++ b/mindspore/ccsrc/backend/session/session_basic.h @@ -68,7 +68,7 @@ class SessionBasic : public std::enable_shared_from_this { virtual ~SessionBasic() { summary_callback_ = nullptr; } - GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs); + GraphId CompileGraph(const GraphSegmentPtr &segment, const AnfNodePtrList &outputs); GraphId CompileGraph(NotNull func_graph); void BuildGraph(GraphId graphId); void RunGraph(const GraphId &graph_id, const std::vector &inputs, VectorRef *outputs); @@ -102,6 +102,8 @@ class SessionBasic : public std::enable_shared_from_this { virtual void GetModelInputsInfo(uint32_t graph_id, std::vector *inputs) const {} std::vector GetInputNeedLockTensors(const GraphId &graph_id, const std::vector &inputs); + // Get graph by graph id, if not exist return null ptr + KernelGraphPtr GetGraph(GraphId graph_id) const; #ifdef ENABLE_DEBUGGER // set debugger void SetDebugger() { @@ -147,8 +149,6 @@ class SessionBasic : public std::enable_shared_from_this { virtual void RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, const std::vector &input_tensors, VectorRef *outputs) {} void RunInfer(NotNull func_graph, const std::vector &inputs); - // Get graph by graph id ,if not exist return null ptr - KernelGraphPtr GetGraph(GraphId graph_id) const; virtual void SetSummaryNodes(KernelGraph *graph); diff --git a/mindspore/ccsrc/pipeline/jit/action.cc b/mindspore/ccsrc/pipeline/jit/action.cc index 1f414d7d73..0c97556a8a 100644 --- a/mindspore/ccsrc/pipeline/jit/action.cc +++ b/mindspore/ccsrc/pipeline/jit/action.cc @@ -354,7 +354,7 @@ bool TaskEmitAction(const ResourcePtr &res) { auto context_ptr = MsContext::GetInstance(); std::string backend = MsContext::GetInstance()->backend_policy(); MS_EXCEPTION_IF_NULL(context_ptr); - if (CompileGraphs::ContainMixedTarget(func_graph)) { + if (func_graph->ContainMultiTarget()) { bc_ptr->set_is_multi_graph_sink(false); context_ptr->set_param(MS_CTX_IS_MULTI_GRAPH_SINK, false); context_ptr->set_param(MS_CTX_ENABLE_LOOP_SINK, false); diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.cc b/mindspore/ccsrc/pipeline/jit/pipeline.cc index 5d65a2ff57..1d0e3664ee 100644 --- a/mindspore/ccsrc/pipeline/jit/pipeline.cc +++ b/mindspore/ccsrc/pipeline/jit/pipeline.cc @@ -923,7 +923,8 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc MS_EXCEPTION_IF_NULL(convert_fn); // Convert CNodeList to LinConvertResult. ConfigManager::GetInstance().set_iter_num(1); - auto runner = convert_fn({app_init}, ""); + auto segment = std::make_shared(std::vector{app_init}, false); + auto runner = convert_fn(segment, ""); if (MsContext::GetInstance()->get_param(MS_CTX_EXECUTION_MODE) != kPynativeMode) { backend->Link(runner.graph_id); } diff --git a/mindspore/ccsrc/vm/backend.cc b/mindspore/ccsrc/vm/backend.cc index 4dd15394b4..3ea63e41fc 100644 --- a/mindspore/ccsrc/vm/backend.cc +++ b/mindspore/ccsrc/vm/backend.cc @@ -34,30 +34,34 @@ namespace compile { bool Backend::GetCond(const BaseRef &c, bool *const value) { return BaseRefToBool(c, value); } bool Backend::GetIndex(const BaseRef &c, int64_t *const value) { return BaseRefToInt(utils::cast(c), value); } -LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst, const std::string &target) { +Backend::Backend(const std::string &name) : name_(name) { + MS_LOG(DEBUG) << "select backend:" << name; + convert_fn_ = MsVmConvert; + is_multi_graph_sink_ = false; +} + +LinConvertResult MsBackend::MsConvert(const GraphSegmentPtr &segment, const std::string &target) { MS_LOG(DEBUG) << "MsConvert"; + MS_EXCEPTION_IF_NULL(segment); MS_EXCEPTION_IF_NULL(MsContext::GetInstance()); - auto cached = g_ConvertCache.find(lst); + auto cached = g_ConvertCache.find(segment); if (cached != g_ConvertCache.end()) { return cached->second; } - LinConvertResult result; - FuncGraphPtr fg; AnfNodePtrList inputs; AnfNodePtrList outputs; - - std::tie(fg, inputs, outputs) = TransformSegmentToAnfGraph(lst); + std::tie(fg, inputs, outputs) = TransformSegmentToAnfGraph(segment->nodes_); result.inputs = inputs; result.outputs = outputs; result.graph_id = kInvalidGraphId; GraphId graph_id = kInvalidGraphId; if (target != target_device_ && !target.empty()) { CreateOtherSession(target); - graph_id = other_sess_->CompileGraph(lst, outputs); + graph_id = other_sess_->CompileGraph(segment, outputs); } else { - graph_id = target_sess_->CompileGraph(lst, outputs); + graph_id = target_sess_->CompileGraph(segment, outputs); } if (MsContext::GetInstance()->get_param(MS_CTX_PRECOMPILE_ONLY)) { @@ -79,7 +83,7 @@ LinConvertResult MsBackend::MsConvert(const AnfNodePtrList &lst, const std::stri result.graph_id = graph_id; graph_id_map_[graph_id] = result; - (void)g_ConvertCache.emplace(lst, result); + (void)g_ConvertCache.emplace(segment, result); return result; } @@ -154,12 +158,6 @@ void MsBackend::Link(GraphId graph_id) { target_sess_->BuildGraph(graph_id); } -Backend::Backend(const std::string &name) : name_(name) { - MS_LOG(DEBUG) << "select backend:" << name; - convert_fn_ = backends[name_]; - is_multi_graph_sink_ = false; -} - MsBackend::MsBackend(const std::string &name, const std::string &target, uint32_t device_id) : Backend(name) { convert_fn_ = std::bind(&MsBackend::MsConvert, this, std::placeholders::_1, std::placeholders::_2); target_sess_ = session::SessionFactory::Get().Create(target); @@ -194,6 +192,5 @@ VectorRef MsBackend::RunGraph(GraphId graph_id, const VectorRef &args) { return #ifdef ENABLE_DEBUGGER void MsBackend::SetDebugger() { target_sess_->SetDebugger(); } #endif - } // namespace compile } // namespace mindspore diff --git a/mindspore/ccsrc/vm/backend.h b/mindspore/ccsrc/vm/backend.h index 2ed9ff36ad..c7f1208057 100644 --- a/mindspore/ccsrc/vm/backend.h +++ b/mindspore/ccsrc/vm/backend.h @@ -25,6 +25,7 @@ #include "utils/contract.h" #include "ir/anf.h" #include "vm/segment_runner.h" +#include "vm/graph_partition.h" #include "vm/vm.h" #include "backend/session/session_basic.h" @@ -63,7 +64,7 @@ class MsBackend : public Backend { MsBackend(const std::string &name, const std::string &target, uint32_t device_id); ~MsBackend() override = default; - LinConvertResult MsConvert(const AnfNodePtrList &lst, const std::string &target = ""); + LinConvertResult MsConvert(const GraphSegmentPtr &segment, const std::string &target = ""); VectorRef MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target = ""); VectorRef MsSimuRunGraph(const GraphId &g, const VectorRef &args); diff --git a/mindspore/ccsrc/vm/graph_partition.cc b/mindspore/ccsrc/vm/graph_partition.cc new file mode 100644 index 0000000000..5547e5a5e4 --- /dev/null +++ b/mindspore/ccsrc/vm/graph_partition.cc @@ -0,0 +1,447 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "vm/graph_partition.h" +#include +#include +#include +#include +#include +#include +#include +#include "base/core_ops.h" +#include "utils/utils.h" +#include "utils/ms_context.h" +namespace mindspore { +const char kMsConvert[] = "ms"; +const char kMsVm[] = "vm"; +const char kGeVm[] = "ge"; +namespace compile { +namespace { +bool ExtractNodes(const FuncGraphPtr &graph, const AnfNodePtr &prior_node, const AnfNodePtr &behind_node, + std::vector *prior_nodes, std::vector *depend_nodes) { + MS_EXCEPTION_IF_NULL(prior_node); + MS_EXCEPTION_IF_NULL(behind_node); + MS_EXCEPTION_IF_NULL(graph); + auto manager = graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + auto &node_users = manager->node_users(); + if (prior_node->isa()) { + for (auto &user : node_users[prior_node]) { + auto cnode = user.first->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (!IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) { + prior_nodes->emplace_back(cnode); + } + } + } else if (!IsPrimitiveCNode(prior_node, prim::kPrimControlDepend)) { + prior_nodes->emplace_back(prior_node); + } else { + return false; + } + if (behind_node->isa()) { + for (auto &user : node_users[behind_node]) { + auto cnode = user.first->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (!IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) { + depend_nodes->emplace_back(cnode); + } + } + } else if (!IsPrimitiveCNode(behind_node, prim::kPrimControlDepend)) { + depend_nodes->emplace_back(behind_node); + } else { + return false; + } + return true; +} + +void AddControlEdge(const FuncGraphPtr &graph, const AnfNodePtr &node, + std::map> *control_edges, + std::map *nodes_ref) { + MS_EXCEPTION_IF_NULL(node); + auto input_cnode = node->cast(); + MS_EXCEPTION_IF_NULL(input_cnode); + auto prior_node = input_cnode->input(kControlDependPriorIndex); + auto depend_node = input_cnode->input(kControlDependBehindIndex); + MS_EXCEPTION_IF_NULL(prior_node); + MS_EXCEPTION_IF_NULL(depend_node); + PrimitivePtr prim_ptr = GetValueNode(input_cnode->input(0)); + MS_EXCEPTION_IF_NULL(prim_ptr); + ValuePtr mode_ptr = prim_ptr->GetAttr("depend_mode"); + int64_t depend_mode = 0; + if (mode_ptr != nullptr) { + depend_mode = GetValue(mode_ptr); + } + if ((prior_node->isa() || depend_node->isa()) && depend_mode == 0) { + return; + } + std::vector prior_nodes; + std::vector behind_nodes; + if (!ExtractNodes(graph, prior_node, depend_node, &prior_nodes, &behind_nodes)) { + return; + } + for (auto &first_node : prior_nodes) { + for (auto &second_node : behind_nodes) { + MS_EXCEPTION_IF_NULL(first_node); + MS_EXCEPTION_IF_NULL(second_node); + auto iter = control_edges->find(second_node); + if (iter == control_edges->end()) { + (void)control_edges->insert( + std::pair>(second_node, std::vector{first_node})); + } else { + iter->second.emplace_back(first_node); + } + auto ref_iter = nodes_ref->find(first_node); + if (ref_iter != nodes_ref->end()) { + ref_iter->second++; + } else { + (void)nodes_ref->insert(std::pair(first_node, 1)); + } + } + } +} + +void CalcNodeRefCount(const FuncGraphPtr &graph, std::map *nodes_ref, + std::map> *control_edges) { + std::queue queue; + queue.push(graph->get_return()); + std::set visited; + while (!queue.empty()) { + auto &node = queue.front(); + queue.pop(); + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + continue; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + for (auto &input : cnode->inputs()) { + if (IsPrimitiveCNode(input, prim::kPrimControlDepend)) { + AddControlEdge(graph, input, control_edges, nodes_ref); + } + auto iter = nodes_ref->find(input); + if (iter != nodes_ref->end()) { + iter->second++; + } else { + (void)nodes_ref->insert(std::pair(input, 1)); + } + if (visited.find(input) != visited.end()) { + continue; + } + visited.insert(input); + queue.push(input); + } + } +} + +std::vector OptimizeGetItemOrder(const std::vector &nodes) { + std::vector result; + std::map> insert_positions; + std::map node_positions; + for (auto &node : nodes) { + if (node->isa() && IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) { + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto &inputs = cnode->inputs(); + if (inputs.size() < 2) { + MS_LOG(EXCEPTION) << "Invalid get item node"; + } + auto &parent = inputs[1]; + auto iter = node_positions.find(parent); + if (iter != node_positions.end()) { + size_t position = iter->second; + auto iter_nodes = insert_positions.find(position); + if (iter_nodes != insert_positions.end()) { + iter_nodes->second.push_back(node); + } else { + (void)insert_positions.insert( + std::pair>(position, std::vector{node})); + } + continue; + } + } + result.emplace_back(node); + node_positions[node] = result.size(); + } + + size_t insert_num = 0; + for (auto &item : insert_positions) { + size_t position = item.first + insert_num; + (void)result.insert(result.begin() + position, item.second.begin(), item.second.end()); + insert_num += item.second.size(); + } + return result; +} + +std::vector SplitSort(const FuncGraphPtr &graph, const std::string &default_target) { + std::vector result; + std::stack to_visit; + std::stack next_to_visit; + std::map nodes_ref; + std::map> control_edges; + CalcNodeRefCount(graph, &nodes_ref, &control_edges); + std::string handle_target = default_target; + std::string next_target = ""; + to_visit.push(graph->get_return()); + while (!to_visit.empty() || !next_to_visit.empty()) { + if (to_visit.empty()) { + to_visit.swap(next_to_visit); + handle_target = next_target; + } + auto &node = to_visit.top(); + MS_EXCEPTION_IF_NULL(node); + to_visit.pop(); + result.emplace_back(node); + if (!node->isa()) { + continue; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto node_inputs = cnode->inputs(); + std::reverse(node_inputs.begin(), node_inputs.end()); + auto ctrl_inputs = control_edges.find(node); + if (ctrl_inputs != control_edges.end()) { + node_inputs.insert(node_inputs.end(), ctrl_inputs->second.begin(), ctrl_inputs->second.end()); + } + for (auto &input : node_inputs) { + auto iter = nodes_ref.find(input); + if (iter != nodes_ref.end()) { + iter->second--; + if (iter->second != 0) { + continue; + } + } + if (!input->isa()) { + to_visit.push(input); + continue; + } + std::string input_target = GetCNodeTarget(input); + if (input_target == handle_target) { + to_visit.push(input); + } else if (next_to_visit.empty() || input_target == next_target) { + next_to_visit.push(input); + next_target = input_target; + } else { + MS_LOG(EXCEPTION) << "Only support two different target"; + } + } + } + std::reverse(result.begin(), result.end()); + return result; +} + +std::vector ParallelSplitSort(const FuncGraphPtr &graph, const std::string &default_target) { + std::vector result; + std::stack handle_nodes; + std::stack next_handle_nodes; + std::stack wait_handle_nodes; + std::map nodes_ref; + std::map> control_edges; + CalcNodeRefCount(graph, &nodes_ref, &control_edges); + std::string handle_target = default_target; + std::string next_target = ""; + handle_nodes.push(graph->get_return()); + while (!handle_nodes.empty() || !next_handle_nodes.empty() || !wait_handle_nodes.empty()) { + if (handle_nodes.empty()) { + handle_nodes.swap(next_handle_nodes); + handle_target.swap(next_target); + next_handle_nodes.swap(wait_handle_nodes); + } + auto &node = handle_nodes.top(); + MS_EXCEPTION_IF_NULL(node); + handle_nodes.pop(); + result.emplace_back(node); + if (!node->isa()) { + continue; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto node_inputs = cnode->inputs(); + std::reverse(node_inputs.begin(), node_inputs.end()); + auto ctrl_inputs = control_edges.find(node); + if (ctrl_inputs != control_edges.end()) { + node_inputs.insert(node_inputs.end(), ctrl_inputs->second.begin(), ctrl_inputs->second.end()); + } + std::vector same_target_ready_inputs; + std::vector diff_target_ready_inputs; + for (auto &input : node_inputs) { + auto iter = nodes_ref.find(input); + if (iter != nodes_ref.end()) { + iter->second--; + if (iter->second != 0) { + continue; + } + } + if (!input->isa()) { + handle_nodes.push(input); + continue; + } + std::string input_target = GetCNodeTarget(input); + if (input_target == handle_target) { + same_target_ready_inputs.emplace_back(input); + } else { + if (next_target.empty()) { + next_target = input_target; + } + if (input_target != next_target) { + MS_LOG(EXCEPTION) << "Only support two different target"; + } + diff_target_ready_inputs.emplace_back(input); + } + } + if (diff_target_ready_inputs.size() == 0) { + for (auto &input : same_target_ready_inputs) { + handle_nodes.push(input); + } + } else { + for (auto &input : same_target_ready_inputs) { + wait_handle_nodes.push(input); + } + for (auto &input : diff_target_ready_inputs) { + next_handle_nodes.push(input); + } + } + } + std::reverse(result.begin(), result.end()); + return result; +} + +bool IsSubGraph(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + if (node->isa()) { + auto cnode = node->cast(); + auto &inputs = cnode->inputs(); + if (inputs.empty()) { + MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; + } + + AnfNodePtr fn = inputs[0]; + if (!IsValueNode(fn)) { + return false; + } + auto node_prim = GetValueNode(fn); + if (node_prim->name() == prim::kPrimPartial->name()) { + return true; + } + } else if (IsValueNode(node)) { + return true; + } + return false; +} +} // namespace + +GraphPartition::GraphPartition(const std::vector &cut_list, const std::string &backend_name) + : cut_list_(cut_list), backend_name_(backend_name) {} + +bool GraphPartition::IsCut(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + if (node->isa()) { + auto cnode = node->cast(); + auto &inputs = cnode->inputs(); + if (inputs.empty()) { + MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; + } + AnfNodePtr fn = inputs[0]; + if (IsValueNode(fn)) { + auto fg = GetValueNode(fn); + if (fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { + return false; + } + } + if (!IsValueNode(fn)) { + return true; + } + PrimitivePtr node_prim = GetValueNode(fn); + for (auto &prim : cut_list_) { + MS_EXCEPTION_IF_NULL(prim); + if (prim->name() == node_prim->name()) { + if (prim->name() == prim::kPrimBpropCut->name()) { + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + ms_context->set_param(MS_CTX_ENABLE_PYNATIVE_HOOK, true); + } + if (backend_name_ == kMsConvert && prim->name() == prim::kPrimMakeTuple->name()) { + if (inputs.size() < 2) { + return false; + } + auto ret = IsSubGraph(inputs[1]); + return ret; + } + return true; + } + } +#ifdef ENABLE_GE + if (backend_name_ == kGeVm) { + auto name = GetCNodeFuncName(cnode); + auto adpt = transform::DfGraphConvertor::FindAdapter(name); + if (adpt == nullptr) { + return true; + } + } +#endif + } + return false; +} + +std::vector GraphPartition::Partition(const FuncGraphPtr &graph) { + MS_EXCEPTION_IF_NULL(graph); + auto nodes = TopoSort(graph->get_return()); + MS_LOG(DEBUG) << "Split all nodes size:" << nodes.size(); + bool contain_multi_target = ContainMultiTarget(nodes); + if (contain_multi_target) { + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + std::string default_target = context_ptr->get_param(MS_CTX_DEVICE_TARGET); + if (graph != nullptr) { + nodes = SplitSort(graph, default_target); + } else { + nodes = ParallelSplitSort(graph, default_target); + } + nodes = OptimizeGetItemOrder(nodes); + } + std::vector segments; + std::vector segment_nodes; + std::string last_target; + for (auto &node : nodes) { + MS_EXCEPTION_IF_NULL(node); + if (IsCut(node)) { + if (segment_nodes.size() != 0) { + auto segment = std::make_shared(segment_nodes, false); + segments.emplace_back(segment); + segment_nodes.clear(); + } + segment_nodes.emplace_back(node); + auto segment = std::make_shared(segment_nodes, true); + segments.push_back(segment); + segment_nodes.clear(); + } else if (node->isa()) { + if (contain_multi_target) { + std::string cur_target = GetCNodeTarget(node); + if (cur_target != last_target && !last_target.empty() && segment_nodes.size() != 0) { + auto segment = std::make_shared(segment_nodes, false); + segments.emplace_back(segment); + segment_nodes.clear(); + } + last_target = cur_target; + } + segment_nodes.emplace_back(node); + } + } + MS_LOG(DEBUG) << "Segment size:" << segments.size(); + return segments; +} +} // namespace compile +} // namespace mindspore diff --git a/mindspore/ccsrc/vm/graph_partition.h b/mindspore/ccsrc/vm/graph_partition.h new file mode 100644 index 0000000000..d021fa0b1c --- /dev/null +++ b/mindspore/ccsrc/vm/graph_partition.h @@ -0,0 +1,48 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_VM_GRAPH_PARTITION_H_ +#define MINDSPORE_CCSRC_VM_GRAPH_PARTITION_H_ +#include +#include +#include +#include "ir/anf.h" +#include "ir/func_graph.h" +#include "ir/graph_utils.h" +#include "base/base_ref.h" + +namespace mindspore { +extern const char kMsVm[]; +extern const char kGeVm[]; +extern const char kMsConvert[]; + +namespace compile { +class GraphPartition { + public: + explicit GraphPartition(const std::vector &cut_list, const std::string &backend_name); + ~GraphPartition() = default; + std::vector Partition(const FuncGraphPtr &func_graph); + + private: + bool IsCut(const AnfNodePtr &node); + std::vector cut_list_; + std::string backend_name_; +}; + +using GraphPartitionPtr = std::shared_ptr; +} // namespace compile +} // namespace mindspore +#endif // MINDSPORE_CCSRC_VM_GRAPH_PARTITION_H_ diff --git a/mindspore/ccsrc/vm/segment_runner.cc b/mindspore/ccsrc/vm/segment_runner.cc index a6b13e196d..4e304d2ea8 100644 --- a/mindspore/ccsrc/vm/segment_runner.cc +++ b/mindspore/ccsrc/vm/segment_runner.cc @@ -34,10 +34,6 @@ #include "frontend/operator/ops.h" namespace mindspore { -const char kMsConvert[] = "ms"; -const char kMsVm[] = "vm"; -const char kGeVm[] = "ge"; - namespace compile { // cached conversion ConvertCache g_ConvertCache; @@ -194,8 +190,9 @@ std::tuple TransformSegmentToAnfGr // This implementation will convert the nodes into a subgraph // that will run using the MsVM. template -LinConvertResult Convert(const AnfNodePtrList &lst, const std::string &) { - auto cached = g_ConvertCache.find(lst); +LinConvertResult Convert(const GraphSegmentPtr &segment, const std::string &) { + MS_EXCEPTION_IF_NULL(segment); + auto cached = g_ConvertCache.find(segment); if (cached != g_ConvertCache.end()) { return cached->second; } @@ -206,7 +203,7 @@ LinConvertResult Convert(const AnfNodePtrList &lst, const std::string &) { AnfNodePtrList inputs; AnfNodePtrList outputs; - std::tie(fg, inputs, outputs) = TransformSegmentToAnfGraph(lst); + std::tie(fg, inputs, outputs) = TransformSegmentToAnfGraph(segment->nodes_); // Clone in case g contains subgraphs that have a different manager fg = BasicClone(fg); @@ -219,18 +216,15 @@ LinConvertResult Convert(const AnfNodePtrList &lst, const std::string &) { result.outputs = outputs; result.graph_id = UINT32_MAX; - (void)g_ConvertCache.emplace(lst, result); + (void)g_ConvertCache.emplace(segment, result); return result; } LinkFuncType MsVmConvert = Convert; -std::unordered_map backends = {{kMsVm, MsVmConvert}}; - std::set backend_list = { kMsConvert, kMsVm, }; - } // namespace compile } // namespace mindspore diff --git a/mindspore/ccsrc/vm/segment_runner.h b/mindspore/ccsrc/vm/segment_runner.h index c4458d4148..316fed57ca 100644 --- a/mindspore/ccsrc/vm/segment_runner.h +++ b/mindspore/ccsrc/vm/segment_runner.h @@ -27,14 +27,10 @@ #include "ir/anf.h" #include "vm/vmimpl.h" +#include "vm/graph_partition.h" namespace mindspore { -extern const char kMsVm[]; -extern const char kGeVm[]; -extern const char kMsConvert[]; - namespace compile { - struct LinConvertResult { RunFuncPtr run; RunFuncPtr simu_run; @@ -43,11 +39,9 @@ struct LinConvertResult { uint32_t graph_id; }; -using LinkFuncType = std::function; -using ConvertCache = std::unordered_map; +using LinkFuncType = std::function; +using ConvertCache = std::unordered_map; extern LinkFuncType MsVmConvert; -extern LinkFuncType GeVmConvert; -extern std::unordered_map backends; extern ConvertCache g_ConvertCache; extern std::set backend_list; diff --git a/mindspore/ccsrc/vm/transform.cc b/mindspore/ccsrc/vm/transform.cc index 3b1906d22c..e6ad847777 100644 --- a/mindspore/ccsrc/vm/transform.cc +++ b/mindspore/ccsrc/vm/transform.cc @@ -21,8 +21,6 @@ #include #include #include -#include -#include #include #include @@ -52,386 +50,13 @@ const std::vector &GetMsNonlinearOps() { return ms_nonlinear_ops; } -namespace { -bool ContainMultiTarget(const std::vector &nodes) { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - std::string last_target = context_ptr->get_param(MS_CTX_DEVICE_TARGET); - for (auto &node : nodes) { - if (node->isa()) { - std::string cur_target = GetCNodeTarget(node); - if (last_target != cur_target) { - return true; - } - last_target = cur_target; - } - } - return false; -} - -bool ExtractNodes(const FuncGraphPtr &graph, const AnfNodePtr &prior_node, const AnfNodePtr &behind_node, - std::vector *prior_nodes, std::vector *depend_nodes) { - MS_EXCEPTION_IF_NULL(prior_node); - MS_EXCEPTION_IF_NULL(behind_node); - MS_EXCEPTION_IF_NULL(graph); - auto manager = graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - auto &node_users = manager->node_users(); - if (prior_node->isa()) { - for (auto &user : node_users[prior_node]) { - auto cnode = user.first->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (!IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) { - prior_nodes->emplace_back(cnode); - } - } - } else if (!IsPrimitiveCNode(prior_node, prim::kPrimControlDepend)) { - prior_nodes->emplace_back(prior_node); - } else { - return false; - } - if (behind_node->isa()) { - for (auto &user : node_users[behind_node]) { - auto cnode = user.first->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (!IsPrimitiveCNode(cnode, prim::kPrimControlDepend)) { - depend_nodes->emplace_back(cnode); - } - } - } else if (!IsPrimitiveCNode(behind_node, prim::kPrimControlDepend)) { - depend_nodes->emplace_back(behind_node); - } else { - return false; - } - return true; -} - -void AddControlEdge(const FuncGraphPtr &graph, const AnfNodePtr &node, - std::map> *control_edges, - std::map *nodes_ref) { - MS_EXCEPTION_IF_NULL(node); - auto input_cnode = node->cast(); - MS_EXCEPTION_IF_NULL(input_cnode); - auto prior_node = input_cnode->input(kControlDependPriorIndex); - auto depend_node = input_cnode->input(kControlDependBehindIndex); - MS_EXCEPTION_IF_NULL(prior_node); - MS_EXCEPTION_IF_NULL(depend_node); - PrimitivePtr prim_ptr = GetValueNode(input_cnode->input(0)); - MS_EXCEPTION_IF_NULL(prim_ptr); - ValuePtr mode_ptr = prim_ptr->GetAttr("depend_mode"); - int64_t depend_mode = 0; - if (mode_ptr != nullptr) { - depend_mode = GetValue(mode_ptr); - } - if ((prior_node->isa() || depend_node->isa()) && depend_mode == 0) { - return; - } - std::vector prior_nodes; - std::vector behind_nodes; - if (!ExtractNodes(graph, prior_node, depend_node, &prior_nodes, &behind_nodes)) { - return; - } - for (auto &first_node : prior_nodes) { - for (auto &second_node : behind_nodes) { - MS_EXCEPTION_IF_NULL(first_node); - MS_EXCEPTION_IF_NULL(second_node); - auto iter = control_edges->find(second_node); - if (iter == control_edges->end()) { - (void)control_edges->insert( - std::pair>(second_node, std::vector{first_node})); - } else { - iter->second.emplace_back(first_node); - } - auto ref_iter = nodes_ref->find(first_node); - if (ref_iter != nodes_ref->end()) { - ref_iter->second++; - } else { - (void)nodes_ref->insert(std::pair(first_node, 1)); - } - } - } -} - -void CalcNodeRefCount(const FuncGraphPtr &graph, std::map *nodes_ref, - std::map> *control_edges) { - std::queue queue; - queue.push(graph->get_return()); - std::set visited; - while (!queue.empty()) { - auto &node = queue.front(); - queue.pop(); - MS_EXCEPTION_IF_NULL(node); - if (!node->isa()) { - continue; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - for (auto &input : cnode->inputs()) { - if (IsPrimitiveCNode(input, prim::kPrimControlDepend)) { - AddControlEdge(graph, input, control_edges, nodes_ref); - } - auto iter = nodes_ref->find(input); - if (iter != nodes_ref->end()) { - iter->second++; - } else { - (void)nodes_ref->insert(std::pair(input, 1)); - } - if (visited.find(input) != visited.end()) { - continue; - } - visited.insert(input); - queue.push(input); - } - } -} - -std::vector OptimizeGetItemOrder(const std::vector &nodes) { - std::vector result; - std::map> insert_positions; - std::map node_positions; - for (auto &node : nodes) { - if (node->isa() && IsPrimitiveCNode(node, prim::kPrimTupleGetItem)) { - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto &inputs = cnode->inputs(); - if (inputs.size() < 2) { - MS_LOG(EXCEPTION) << "Invalid get item node"; - } - auto &parent = inputs[1]; - auto iter = node_positions.find(parent); - if (iter != node_positions.end()) { - size_t position = iter->second; - auto iter_nodes = insert_positions.find(position); - if (iter_nodes != insert_positions.end()) { - iter_nodes->second.push_back(node); - } else { - (void)insert_positions.insert( - std::pair>(position, std::vector{node})); - } - continue; - } - } - result.emplace_back(node); - node_positions[node] = result.size(); - } - - size_t insert_num = 0; - for (auto &item : insert_positions) { - size_t position = item.first + insert_num; - (void)result.insert(result.begin() + position, item.second.begin(), item.second.end()); - insert_num += item.second.size(); - } - return result; -} - -std::vector SplitSort(const FuncGraphPtr &graph, const std::string &default_target) { - std::vector result; - std::stack to_visit; - std::stack next_to_visit; - std::map nodes_ref; - std::map> control_edges; - CalcNodeRefCount(graph, &nodes_ref, &control_edges); - std::string handle_target = default_target; - std::string next_target = ""; - to_visit.push(graph->get_return()); - while (!to_visit.empty() || !next_to_visit.empty()) { - if (to_visit.empty()) { - to_visit.swap(next_to_visit); - handle_target = next_target; - } - auto &node = to_visit.top(); - MS_EXCEPTION_IF_NULL(node); - to_visit.pop(); - result.emplace_back(node); - if (!node->isa()) { - continue; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto node_inputs = cnode->inputs(); - std::reverse(node_inputs.begin(), node_inputs.end()); - auto ctrl_inputs = control_edges.find(node); - if (ctrl_inputs != control_edges.end()) { - node_inputs.insert(node_inputs.end(), ctrl_inputs->second.begin(), ctrl_inputs->second.end()); - } - for (auto &input : node_inputs) { - auto iter = nodes_ref.find(input); - if (iter != nodes_ref.end()) { - iter->second--; - if (iter->second != 0) { - continue; - } - } - if (!input->isa()) { - to_visit.push(input); - continue; - } - std::string input_target = GetCNodeTarget(input); - if (input_target == handle_target) { - to_visit.push(input); - } else if (next_to_visit.empty() || input_target == next_target) { - next_to_visit.push(input); - next_target = input_target; - } else { - MS_LOG(EXCEPTION) << "only support two different target"; - } - } - } - std::reverse(result.begin(), result.end()); - return result; -} - -bool IsSubGraph(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - if (node->isa()) { - auto cnode = node->cast(); - auto &inputs = cnode->inputs(); - if (inputs.empty()) { - MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; - } - - AnfNodePtr fn = inputs[0]; - if (!IsValueNode(fn)) { - return false; - } - auto node_prim = GetValueNode(fn); - if (node_prim->name() == prim::kPrimPartial->name()) { - return true; - } - } else if (IsValueNode(node)) { - return true; - } - return false; -} -} // namespace - -CompileGraph::CompileGraph(const BackendPtr &backend, const std::vector &cut_list) - : backend_(backend), cut_list_(cut_list) { +CompileGraph::CompileGraph(const BackendPtr &backend, const std::vector &cut_list) : backend_(backend) { MS_EXCEPTION_IF_NULL(backend_); lin_convert_ = backend_->convert_fn(); if (lin_convert_ == nullptr) { MS_LOG(EXCEPTION) << "Attribute 'lin_convert' is null.: " << backend->name(); } - - is_gevm_convert_ = false; - if (backend->name() == kGeVm) { - MS_LOG(INFO) << "Attribute 'is_gevm_convert' is true"; - is_gevm_convert_ = true; - } -} - -bool CompileGraph::IsCut(const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(node); - if (node->isa()) { - auto cnode = node->cast(); - auto &inputs = cnode->inputs(); - if (inputs.empty()) { - MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; - } - - AnfNodePtr fn = inputs[0]; - if (IsValueNode(fn)) { - auto fg = GetValueNode(fn); - if (fg->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) { - return false; - } - } - - if (!IsValueNode(fn)) { - return true; - } - - PrimitivePtr node_prim = GetValueNode(fn); - for (auto &prim : cut_list_) { - MS_EXCEPTION_IF_NULL(prim); - if (prim->name() == node_prim->name()) { - if (prim->name() == prim::kPrimBpropCut->name()) { - auto ms_context = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(ms_context); - ms_context->set_param(MS_CTX_ENABLE_PYNATIVE_HOOK, true); - } - - if (backend_->name() == kMsConvert && prim->name() == prim::kPrimMakeTuple->name()) { - if (inputs.size() < 2) { - return false; - } - auto ret = IsSubGraph(inputs[1]); - return ret; - } - - return true; - } - } - -#ifdef ENABLE_GE - if (is_gevm_convert_) { - auto name = GetCNodeFuncName(cnode); - auto adpt = transform::DfGraphConvertor::FindAdapter(name); - if (adpt == nullptr) { - return true; - } - } -#endif - } - - return false; -} - -VectorRef CompileGraph::SplitNodesWithTarget(const std::vector &input_nodes, const FuncGraphPtr &graph) { - MS_EXCEPTION_IF_NULL(graph); - auto nodes = OptimizeGetItemOrder(input_nodes); - VectorRef splits; - VectorRef split; - std::string last_target; - for (auto &node : nodes) { - MS_EXCEPTION_IF_NULL(node); - if (IsCut(node)) { - if (split.size() != 0) { - splits.push_back(split); - } - splits.push_back(node); - split.clear(); - } else if (node->isa()) { - std::string cur_target = GetCNodeTarget(node); - if (cur_target != last_target && !last_target.empty() && split.size() != 0) { - splits.push_back(split); - split.clear(); - } - last_target = cur_target; - split.push_back(node); - } - } - return splits; -} - -VectorRef CompileGraph::SplitNodes(const FuncGraphPtr &graph) { - MS_EXCEPTION_IF_NULL(graph); - auto nodes = TopoSort(graph->get_return()); - MS_LOG(DEBUG) << "Split all nodes size:" << nodes.size(); - - if (ContainMultiTarget(nodes)) { - auto context_ptr = MsContext::GetInstance(); - MS_EXCEPTION_IF_NULL(context_ptr); - std::string default_target = context_ptr->get_param(MS_CTX_DEVICE_TARGET); - nodes = SplitSort(graph, default_target); - return SplitNodesWithTarget(nodes, graph); - } - - VectorRef splits; - VectorRef split; - for (auto &node : nodes) { - MS_EXCEPTION_IF_NULL(node); - if (IsCut(node)) { - if (split.size() != 0) { - splits.push_back(split); - } - splits.push_back(node); - split.clear(); - } else if (node->isa()) { - split.push_back(node); - } - } - return splits; + graph_partition_ = std::make_shared(cut_list, backend->name()); } // Push the value node on the stack. @@ -512,12 +137,12 @@ void CompileGraph::PushParameters(const FuncGraphPtr &graph) { } } -int64_t CompileGraph::LinConvert(const FuncGraphPtr &graph, const AnfNodePtrList &node_list, - const std::string &target) { +int64_t CompileGraph::LinConvert(const FuncGraphPtr &graph, const GraphSegmentPtr &segment, const std::string &target) { + MS_EXCEPTION_IF_NULL(segment); MS_LOG(DEBUG) << "LinConvert start"; LinConvertResult result; - result = lin_convert_(node_list, target); + result = lin_convert_(segment, target); if (result.run == nullptr) { MS_LOG(ERROR) << "LinConvert failed"; @@ -583,25 +208,23 @@ int64_t CompileGraph::InterpretNode(const FuncGraphPtr &graph, const CNodePtr &n return RET_SUCCESS; } -bool CompileGraph::SplitGraph(const FuncGraphPtr &graph) { +bool CompileGraph::Compile(const FuncGraphPtr &graph) { MS_LOG(DEBUG) << "Start split graph"; MS_EXCEPTION_IF_NULL(graph); - VectorRef splits = SplitNodes(graph); + MS_EXCEPTION_IF_NULL(graph_partition_); + auto segments = graph_partition_->Partition(graph); - MS_LOG(DEBUG) << "Split nodes size:" << splits.size(); - for (auto &split : splits) { + MS_LOG(DEBUG) << "Split nodes size:" << segments.size(); + for (auto &segment : segments) { + MS_EXCEPTION_IF_NULL(segment); int64_t ret = RET_SUCCESS; - if (utils::isa(split)) { + if (!segment->is_cut_) { MS_LOG(DEBUG) << "Start a extern LinConvert"; - std::vector args; - auto vec_ref = utils::cast(split); - (void)std::transform(vec_ref.begin(), vec_ref.end(), std::back_inserter(args), - [](const BaseRef &v) { return utils::cast(v); }); - if (args.size() > 0) { - std::string cur_target = GetCNodeTarget(args[0]); - ret = LinConvert(graph, args, cur_target); + if (segment->nodes_.size() > 0) { + std::string cur_target = GetCNodeTarget(segment->nodes_[0]); + ret = LinConvert(graph, segment, cur_target); } else { - ret = LinConvert(graph, args); + ret = LinConvert(graph, segment); } MS_LOG(DEBUG) << "End a extern LinConvert"; if (ret == RET_FAILED) { @@ -612,10 +235,11 @@ bool CompileGraph::SplitGraph(const FuncGraphPtr &graph) { } } else { MS_LOG(DEBUG) << "Start a cut node"; - if (!(utils::isa(split) && utils::cast(split)->isa())) { + auto &cut_node = segment->nodes_[0]; + if (!cut_node->isa()) { MS_LOG(EXCEPTION) << "must be anfnode here NodeInfo: " << trace::GetDebugInfo(graph->debug_info()); } - CNodePtr node = utils::cast(split)->cast(); + CNodePtr node = cut_node->cast(); ret = InterpretNode(graph, node); MS_LOG(DEBUG) << "End a cut node"; if (ret == RET_BREAK) { @@ -635,7 +259,7 @@ InstSet CompileGraph::Run(const FuncGraphPtr &graph) { int64_t param_height = height_; MS_LOG(DEBUG) << "'param_height': " << height_ << " to split graph: " << graph->get_return()->DebugString(true); - if (!SplitGraph(graph)) { + if (!Compile(graph)) { return inst_; } @@ -897,20 +521,6 @@ FinalVMPtr CompileGraphs::CompileAndLink(const FuncGraphPtr &graph) { return rt; } -bool CompileGraphs::ContainMixedTarget(const FuncGraphPtr &graph) { - MS_EXCEPTION_IF_NULL(graph); - auto graph_manager = graph->manager(); - MS_EXCEPTION_IF_NULL(graph_manager); - FuncGraphSet graphs = graph_manager->func_graphs(); - for (auto &g : graphs) { - auto nodes = TopoSort(g->get_return()); - if (ContainMultiTarget(nodes)) { - return true; - } - } - return false; -} - BackendPtr CreateBackend() { auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); diff --git a/mindspore/ccsrc/vm/transform.h b/mindspore/ccsrc/vm/transform.h index 89e36c6ff4..249ae94d58 100644 --- a/mindspore/ccsrc/vm/transform.h +++ b/mindspore/ccsrc/vm/transform.h @@ -31,6 +31,7 @@ #include "frontend/operator/ops.h" #include "vm/segment_runner.h" #include "vm/backend.h" +#include "vm/graph_partition.h" // mindspore namespace is the top level namespace of MindSpore project. // Other namespace should be a sub namespace of mindspore namespace in the ME project. @@ -59,7 +60,6 @@ class CompileGraph { void Tie(const AnfNodePtr &n1, const AnfNodePtr &n2) { slots_[n2] = slots_[n1]; } void Ret(int64_t nargs); int64_t Ref(const AnfNodePtr &node); - VectorRef SplitNodes(const FuncGraphPtr &func_graph); void set_height(int64_t h) { height_ = h; @@ -76,10 +76,9 @@ class CompileGraph { } private: - VectorRef SplitNodesWithTarget(const std::vector &input_nodes, const FuncGraphPtr &graph); void PushParameters(const FuncGraphPtr &func_graph); - bool SplitGraph(const FuncGraphPtr &func_graph); - int64_t LinConvert(const FuncGraphPtr &func_graph, const AnfNodePtrList &node_list, const std::string &target = ""); + bool Compile(const FuncGraphPtr &func_graph); + int64_t LinConvert(const FuncGraphPtr &func_graph, const GraphSegmentPtr &segment, const std::string &target = ""); int64_t InterpretNode(const FuncGraphPtr &func_graph, const CNodePtr &node); int64_t AddCall(const FuncGraphPtr &graph, const CNodePtr &node); void AddPadStack(int64_t param_height); @@ -97,11 +96,12 @@ class CompileGraph { void AddInst(const Instruction &inst, const VectorRef &args); BackendPtr backend_; + GraphPartitionPtr graph_partition_; LinkFuncType lin_convert_; - bool is_gevm_convert_; + int64_t height_{0}; int64_t max_height_{0}; - std::vector cut_list_; + std::unordered_map slots_; InstSet inst_; }; @@ -123,7 +123,6 @@ class CompileGraphs { void Compile(const FuncGraphPtr &func_graph); FinalVMPtr Link(const FuncGraphPtr &func_graph); FinalVMPtr CompileAndLink(const FuncGraphPtr &func_graph); - static bool ContainMixedTarget(const FuncGraphPtr &graph); private: InstSet insts_; diff --git a/mindspore/core/ir/anf.cc b/mindspore/core/ir/anf.cc index 52874d561f..2645910314 100644 --- a/mindspore/core/ir/anf.cc +++ b/mindspore/core/ir/anf.cc @@ -301,4 +301,20 @@ std::string GetCNodeTarget(const AnfNodePtr &node) { } return default_target; } + +bool ContainMultiTarget(const std::vector &nodes) { + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + std::string last_target = context_ptr->get_param(MS_CTX_DEVICE_TARGET); + for (auto &node : nodes) { + if (node->isa()) { + std::string cur_target = GetCNodeTarget(node); + if (last_target != cur_target) { + return true; + } + last_target = cur_target; + } + } + return false; +} } // namespace mindspore diff --git a/mindspore/core/ir/anf.h b/mindspore/core/ir/anf.h index 11f1631ba4..2ad81d61f0 100644 --- a/mindspore/core/ir/anf.h +++ b/mindspore/core/ir/anf.h @@ -482,6 +482,13 @@ void reset_id(); using TaggedNodeMap = std::unordered_map; using TaggedGraph = std::pair; std::string GetCNodeTarget(const AnfNodePtr &node); +bool ContainMultiTarget(const std::vector &nodes); +struct GraphSegment { + GraphSegment(const std::vector &nodes, bool is_cut) : nodes_(nodes), is_cut_(is_cut) {} + std::vector nodes_; + bool is_cut_{false}; +}; +using GraphSegmentPtr = std::shared_ptr; } // namespace mindspore #endif // MINDSPORE_CORE_IR_ANF_H_ diff --git a/mindspore/core/ir/func_graph.cc b/mindspore/core/ir/func_graph.cc index 3a6808ca71..eb14a03c50 100644 --- a/mindspore/core/ir/func_graph.cc +++ b/mindspore/core/ir/func_graph.cc @@ -647,6 +647,19 @@ ParameterPtr FuncGraph::add_weight(const tensor::MetaTensorPtr &meta_tensor) { return parameter; } +bool FuncGraph::ContainMultiTarget() const { + auto graph_manager = manager(); + MS_EXCEPTION_IF_NULL(graph_manager); + FuncGraphSet graphs = graph_manager->func_graphs(); + for (auto &g : graphs) { + auto nodes = TopoSort(g->get_return()); + if (mindspore::ContainMultiTarget(nodes)) { + return true; + } + } + return false; +} + size_t NewFgSeenGeneration() { static size_t fg_seen_generation = 0; return ++fg_seen_generation; diff --git a/mindspore/core/ir/func_graph.h b/mindspore/core/ir/func_graph.h index 3702ae13ef..2df1ec4110 100644 --- a/mindspore/core/ir/func_graph.h +++ b/mindspore/core/ir/func_graph.h @@ -354,6 +354,7 @@ class FuncGraph : public FuncGraphBase { static void set_drawer(Drawer drawer) { drawer_ = drawer; } std::shared_ptr switch_layer_input() const { return switch_layer_input_; } void set_switch_layer_input(std::shared_ptr switch_layer_input) { switch_layer_input_ = switch_layer_input; } + bool ContainMultiTarget() const; private: // graph is manipulated by manager and others diff --git a/tests/ut/cpp/vm/segment_runner_test.cc b/tests/ut/cpp/vm/segment_runner_test.cc index fb4a13d181..bafda55c55 100644 --- a/tests/ut/cpp/vm/segment_runner_test.cc +++ b/tests/ut/cpp/vm/segment_runner_test.cc @@ -52,21 +52,11 @@ TEST_F(TestCompileSegmentRunner, test_MsVmConvert1) { std::shared_ptr manager = mindspore::Manage(g); BackendPtr b = std::make_shared("vm"); - CompileGraph transform_(b); - auto splits = transform_.SplitNodes(g); + auto graph_partition = std::make_shared(nonlinear_ops, b->name()); + auto segments = graph_partition->Partition(g); VectorRef args({1.0, 2.0}); - std::vector todos(splits.size()); - auto it = std::copy_if(std::begin(splits), std::end(splits), std::begin(todos), - [](const BaseRef &seg) -> bool { return utils::isa(seg); }); - todos.resize(std::distance(todos.begin(), it)); - ASSERT_EQ(todos.size(), 1); - - AnfNodePtrList anf_list; - for (auto &item : utils::cast(todos[0])) { - anf_list.push_back(utils::cast(item)); - } - auto convertResult = MsVmConvert(anf_list, ""); + auto convertResult = MsVmConvert(segments[0], ""); auto runResult = (*(convertResult.run))(args); ASSERT_TRUE(runResult.size() == 1 && py::cast(BaseRefToPyData(runResult[0])) == 3.0); } @@ -76,21 +66,11 @@ TEST_F(TestCompileSegmentRunner, test_MsVmConvert2) { std::shared_ptr manager = mindspore::Manage(g); BackendPtr b = std::make_shared("vm"); - CompileGraph transform_(b); - auto splits = transform_.SplitNodes(g); + auto graph_partition = std::make_shared(nonlinear_ops, b->name()); + auto segments = graph_partition->Partition(g); VectorRef args({1.0, 2.0}); - std::vector todos(splits.size()); - auto it = std::copy_if(std::begin(splits), std::end(splits), std::begin(todos), - [](const BaseRef &seg) -> bool { return utils::isa(seg); }); - todos.resize(std::distance(todos.begin(), it)); - ASSERT_EQ(todos.size(), 1); - - AnfNodePtrList anf_list; - for (auto &item : utils::cast(todos[0])) { - anf_list.push_back(utils::cast(item)); - } - auto convertResult = MsVmConvert(anf_list, ""); + auto convertResult = MsVmConvert(segments[0], ""); auto runResult = (*(convertResult.run))(args); ASSERT_TRUE(runResult.size() == 1 && py::cast(BaseRefToPyData(runResult[0])) == 2.0); } @@ -100,21 +80,11 @@ TEST_F(TestCompileSegmentRunner, test_if) { std::shared_ptr manager = mindspore::Manage(g); BackendPtr b = std::make_shared("vm"); - CompileGraph transform_(b); - auto splits = transform_.SplitNodes(g); + auto graph_partition = std::make_shared(nonlinear_ops, b->name()); + auto segments = graph_partition->Partition(g); VectorRef args({1.0, 2.0}); - std::vector todos(splits.size()); - auto it = std::copy_if(std::begin(splits), std::end(splits), std::begin(todos), - [](const BaseRef &seg) -> bool { return utils::isa(seg); }); - todos.resize(std::distance(todos.begin(), it)); - ASSERT_EQ(todos.size(), 1); - - AnfNodePtrList anf_list; - for (auto &item : utils::cast(todos[0])) { - anf_list.push_back(utils::cast(item)); - } - auto convertResult = MsVmConvert(anf_list, ""); + auto convertResult = MsVmConvert(segments[0], ""); auto runResult = (*(convertResult.run))(args); auto result = py::cast(BaseRefToPyData(runResult[0]));