From 9fddc2e1e48d043f52e6a062eed946dadb9d43ac Mon Sep 17 00:00:00 2001
From: fangzehua
Date: Wed, 20 Jan 2021 15:59:54 +0800
Subject: [PATCH] add pipe

---
 .../cache_embedding/cache_embedding.cc        | 301 ++++++++++++++----
 .../cache_embedding/cache_embedding.h         |   2 +-
 mindspore/nn/layer/embedding.py               |  17 +-
 3 files changed, 241 insertions(+), 79 deletions(-)

diff --git a/mindspore/ccsrc/frontend/parallel/cache_embedding/cache_embedding.cc b/mindspore/ccsrc/frontend/parallel/cache_embedding/cache_embedding.cc
index eff428aae0..2032f09a84 100644
--- a/mindspore/ccsrc/frontend/parallel/cache_embedding/cache_embedding.cc
+++ b/mindspore/ccsrc/frontend/parallel/cache_embedding/cache_embedding.cc
@@ -18,6 +18,7 @@
 #include <random>
 #include <memory>
 #include <vector>
+#include <queue>
 #include <utility>
 #include <string>
 #include <algorithm>
@@ -33,6 +34,8 @@ namespace parallel {
 using ParamMap = std::unordered_map<ParameterPtr, ParameterPtr>;
 using ParamSet = std::unordered_set<ParameterPtr>;
 using NodePairList = std::vector<std::pair<AnfNodePtr, AnfNodePtr>>;
+using AnfMap = std::unordered_map<AnfNodePtr, AnfNodePtr>;
+using AnfSet = std::unordered_set<AnfNodePtr>;
 
 ParamMap AddCacheParameters(const FuncGraphPtr &graph, const ParamSet &parameter_cache_enable_set) {
   ParamMap cache_host_params_map;
@@ -408,6 +411,7 @@ CNodePtrList FindSparseGatherV2WithCache(const CNodePtrList &cnodes, const Param
   if (sparse_gather_v2_with_cache.empty()) {
     MS_LOG(EXCEPTION) << "Can not find SparseGatherV2 with cache param.";
   }
+
   auto indices = sparse_gather_v2_with_cache[0]->input(2);
   for (auto &ele : sparse_gather_v2_with_cache) {
     if (ele->input(2) != indices) {
@@ -433,13 +437,227 @@ AnfNodePtr FindGatherV2FromSparseGatherV2(const FuncGraphPtr &graph, const AnfNo
   return gatherv2_nodes[0];
 }
 
-void AddCacheEmbedding(const FuncGraphPtr &graph) {
+AnfSet FindNoRefParams(const FuncGraphPtr &graph) {
+  AnfSet no_ref_params;
+  auto params = graph->parameters();
+  for (auto &anf_param : params) {
+    auto param = anf_param->cast<ParameterPtr>();
+    if (!param->has_default()) {
+      MS_LOG(INFO) << param->DebugString() << " has no default.";
+      no_ref_params.insert(anf_param);
+    }
+  }
+  return no_ref_params;
+}
+
+void RemoveOriginParamFromSet(const CNodePtr &unique_node, AnfSet *no_ref_params) {
+  std::queue<CNodePtr> que;
+  que.push(unique_node);
+  while (!que.empty()) {
+    auto node = que.front();
+    que.pop();
+    auto node_inputs = node->inputs();
+    for (auto &input : node_inputs) {
+      if (input->isa<CNode>()) {
+        que.push(input->cast<CNodePtr>());
+      } else if (input->isa<Parameter>()) {
+        int num = no_ref_params->erase(input);
+        if (num > 0) {
+          MS_LOG(INFO) << "Erase unique_node input from set successfully.";
+          return;
+        }
+      }
+    }
+  }
+  MS_LOG(EXCEPTION) << "Cannot find any parameter that is used by Unique.";
+}
+
+AnfNodePtr CreateOutputNodeParam(const FuncGraphPtr &graph, const AnfNodePtr &ori_input, const std::string &name) {
+  auto ori_input_type = ori_input->Type();
+  auto ori_input_element_type = ori_input_type->cast<TensorTypePtr>()->element();
+  auto ori_input_type_id = ori_input_element_type->type_id();
+  auto ori_input_shp = ori_input->Shape();
+  auto input_shp = ori_input_shp->cast<abstract::ShapePtr>();
+  auto input_shape = input_shp->shape();
+  auto new_tensor = std::make_shared<tensor::Tensor>(ori_input_type_id, input_shape);
+  ParamInfoPtr new_param_info = std::make_shared<ParamInfo>();
+  auto new_param_name = name + "_pipe";
+  new_param_info->set_name(new_param_name);
+  new_tensor->set_param_info(new_param_info);
+  auto new_param = graph->AddWeightParameter(new_param_name);
+  new_param->set_default_param(MakeValue(new_tensor));
+  auto abs_tensor = new_tensor->ToAbstract();
+  new_param->set_abstract(abs_tensor);
+  return new_param->cast<AnfNodePtr>();
+}
+
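+// For every no-ref input (e.g. labels), create a "_pipe" mirror parameter that buffers
+// the previous step's value, keyed by the original node.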
+AnfMap CreateOtherPipeParams(const FuncGraphPtr &graph, const AnfSet &no_ref_params) {
+  AnfMap no_ref_pipe_param_map;
+  for (auto &param : no_ref_params) {
+    auto ori_param = param->cast<ParameterPtr>();
+    auto ori_name = ori_param->name();
+    auto new_param = CreateOutputNodeParam(graph, param, ori_name);
+    no_ref_pipe_param_map[param] = new_param;
+  }
+  return no_ref_pipe_param_map;
+}
+
+AnfNodePtr CreateAssign(const FuncGraphPtr &graph, const AnfNodePtr &res_param, const AnfNodePtr &src_param,
+                        bool is_dynamic = false) {
+  auto assign_prim = prim::kPrimAssign;
+  if (is_dynamic) {
+    assign_prim = prim::kPrimDynamicAssign;
+    assign_prim->set_attr(kAttrPrimitiveTarget, MakeValue("CPU"));
+  }
+  std::vector<AnfNodePtr> assign_nodes{NewValueNode(assign_prim), res_param, src_param};
+  auto assign_status = graph->NewCNode(assign_nodes);
+  return assign_status;
+}
+
+AnfNodePtr FindCNodeOutput(const FuncGraphPtr &graph, const AnfNodePtr &node, int64_t index) {
+  auto manager = graph->manager();
+  auto node_users = manager->node_users()[node];
+  for (auto &node_user : node_users) {
+    if (IsPrimitiveCNode(node_user.first, prim::kPrimTupleGetItem)) {
+      auto cnode = node_user.first->cast<CNodePtr>();
+      auto node_index = cnode->input(2);
+      if (node_index->isa<ValueNode>()) {
+        auto value_node = node_index->cast<ValueNodePtr>();
+        MS_EXCEPTION_IF_NULL(value_node);
+        auto item_idx = GetValue<int64_t>(value_node->value());
+        if (item_idx == index) {
+          return node_user.first;
+        }
+      }
+    }
+  }
+  MS_LOG(EXCEPTION) << "Cannot find " << node->DebugString() << ", outputs:" << index;
+}
+
+AnfNodePtrList ReplaceNoRefToParams(const FuncGraphPtr &graph, const AnfMap &no_ref_pipe_param_map,
+                                    const AnfNodePtr &cache_idx_param, const AnfNodePtr &cache_idx,
+                                    const AnfNodePtr &sparse_gatherv2_indices) {
+  auto manager = graph->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  auto node_users = manager->node_users();
+  AnfNodePtrList control_depend_list;
+  // Redirect the other no-ref pipe params and the dense unique index to their mirrors.
+  for (auto &ele : no_ref_pipe_param_map) {
+    auto user_set = node_users[ele.first];
+    auto assign_status = CreateAssign(graph, ele.second, ele.first);
+    for (auto user_node : user_set) {
+      auto control_depend = CreateControlDepend(graph, user_node.first, assign_status);
+      control_depend_list.emplace_back(control_depend);
+    }
+    if (!manager->Replace(ele.first, ele.second)) {
+      MS_LOG(EXCEPTION) << "pipe param: " << ele.first->DebugString() << ", replace node failed.";
+    }
+  }
+
+  // Redirect the gather indices to the cache idx param, refreshed by DynamicAssign on CPU.
+  auto dynamic_assign_status = CreateAssign(graph, cache_idx_param, cache_idx, true);
+  auto indices_user_set = node_users[sparse_gatherv2_indices];
+  for (auto &user_node : indices_user_set) {
+    auto control_depend = CreateControlDepend(graph, user_node.first, dynamic_assign_status);
+    control_depend_list.emplace_back(control_depend);
+  }
+  if (!manager->Replace(sparse_gatherv2_indices, cache_idx_param)) {
+    MS_LOG(EXCEPTION) << "cache idx param: " << cache_idx_param->DebugString() << ", replace node failed.";
+  }
+  return control_depend_list;
+}
+
+void CacheEmbeddingForTrain(const FuncGraphPtr &graph, bool is_pipe, const CNodePtrList &cnodes,
+                            const CNodePtr &unique_node, const ParamSet &param_cache_enable_set) {
   MS_EXCEPTION_IF_NULL(graph);
-  std::list<CNodePtr> orders = graph->GetOrderedCnodes();
-  CNodePtrList cnodes(orders.begin(), orders.end());
+  auto manager = graph->manager();
+  MS_EXCEPTION_IF_NULL(manager);
   size_t cnodes_size = cnodes.size();
+  auto cache_host_params_map = AddCacheParameters(graph, param_cache_enable_set);
+  auto param_set = MapKeysToSet(cache_host_params_map);
+  ReplaceCacheParams(graph, cache_host_params_map);
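+  // Flag the graph so later passes and the runtime know the embedding cache is active.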
+  graph->set_flag(GRAPH_FLAG_CACHE_ENABLE, true);
+  MS_LOG(INFO) << "Cache enable is set for the graph.";
+
+  CNodePtrList sparse_gatherv2_with_cache = FindSparseGatherV2WithCache(cnodes, param_set);
+  auto unique_node_output_0 = CreateTupleGetItem(graph, unique_node, 0);
+  auto map_cache_idx = CreateMapCacheIdx(graph, unique_node_output_0, cache_host_params_map);
+
+  AnfNodePtrList map_cache_idx_node_outputs;
+  CreateTupleGetItems(graph, map_cache_idx, &map_cache_idx_node_outputs);
+
+  auto node_pair_list = CreateEmbSwapUpdate(graph, cache_host_params_map, map_cache_idx_node_outputs);
+  AnfNodePtrList invalid_nodes;
+  auto cache_idx = map_cache_idx_node_outputs[0];
+  if (!is_pipe) {
+    if (!manager->Replace(sparse_gatherv2_with_cache[0]->input(2), cache_idx)) {
+      MS_LOG(EXCEPTION) << "MapCacheIdx output[0] replace node failed.";
+    }
+    for (auto &ele : node_pair_list) {
+      std::transform(sparse_gatherv2_with_cache.begin(), sparse_gatherv2_with_cache.end(),
+                     std::back_inserter(invalid_nodes), [&graph, &ele](const AnfNodePtr &sparse_gatherv2) {
+                       return CreateControlDepend(graph, ele.first, sparse_gatherv2);
+                     });
+      invalid_nodes.emplace_back(ele.second);
+    }
+  } else {
+    auto cache_idx_param = CreateOutputNodeParam(graph, unique_node->input(1), std::string("cache_idx"));
+    auto unique_index_reverse = FindCNodeOutput(graph, unique_node, 1);
+    auto unique_index_param = CreateOutputNodeParam(graph, unique_index_reverse, std::string("index_dense"));
+    auto no_ref_params = FindNoRefParams(graph);
+    RemoveOriginParamFromSet(unique_node, &no_ref_params);
+    auto no_ref_param_map = CreateOtherPipeParams(graph, no_ref_params);
+    no_ref_param_map[unique_index_reverse] = unique_index_param;
+    auto control_depend_list = ReplaceNoRefToParams(graph, no_ref_param_map, cache_idx_param, cache_idx,
+                                                    sparse_gatherv2_with_cache[0]->input(2));
+    std::copy(control_depend_list.begin(), control_depend_list.end(), std::back_inserter(invalid_nodes));
+    std::transform(node_pair_list.begin(), node_pair_list.end(), std::back_inserter(invalid_nodes),
+                   [](const std::pair<AnfNodePtr, AnfNodePtr> &pair) { return pair.second; });
+  }
+  AnfNodePtr last_node = cnodes[cnodes_size - 1];
+  CNodePtr return_node;
+  if (last_node->isa<CNode>()) {
+    return_node = last_node->cast<CNodePtr>();
+  }
+  MS_EXCEPTION_IF_NULL(return_node);
+  if (!IsPrimitiveCNode(return_node, prim::kPrimReturn)) {
+    MS_LOG(EXCEPTION) << "The last cnode after sorting is not a Return cnode.";
+  }
+  if (return_node->inputs().size() < 2) {
+    MS_LOG(EXCEPTION) << "Number of return node inputs should be greater than or equal to 2.";
+  }
+
+  auto depend_node = CreateDepend(graph, invalid_nodes, return_node->input(1));
+  if (!manager->Replace(return_node->input(1), depend_node)) {
+    MS_LOG(EXCEPTION) << "Depend replace node failed.";
+  }
+}
+
+void CacheEmbeddingForEval(const FuncGraphPtr &graph, const CNodePtrList &cnodes, const CNodePtr &unique_node,
+                           const ParamSet &param_cache_enable_set) {
+  MS_EXCEPTION_IF_NULL(graph);
   auto manager = graph->manager();
   MS_EXCEPTION_IF_NULL(manager);
+  graph->set_flag(GRAPH_FLAG_CACHE_ENABLE, true);
+  MS_LOG(INFO) << "Cache enable is set for the graph.";
+  // Replace GatherV2 with EmbeddingLookup(CPU).
+  auto indices = unique_node->input(1);
+  auto sparse_gatherv2_with_cache = FindSparseGatherV2WithCache(cnodes, param_cache_enable_set);
+  for (auto &ele : sparse_gatherv2_with_cache) {
+    auto anf_ele = ele->cast<AnfNodePtr>();
+    auto gatherv2 = FindGatherV2FromSparseGatherV2(graph, anf_ele);
+    auto param = ele->input(1)->cast<ParameterPtr>();
+    auto embedding_lookup = CreateEmbeddingLookup(graph, param, indices);
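+    // In eval the lookup runs on CPU against the full host table, so no cache index remap is needed.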
+    if (!manager->Replace(gatherv2, embedding_lookup)) {
+      MS_LOG(EXCEPTION) << "EmbeddingLookup replace node failed.";
+    }
+  }
+}
+
+void AddCacheEmbedding(const FuncGraphPtr &graph, bool is_pipe) {
+  MS_EXCEPTION_IF_NULL(graph);
+  std::list<CNodePtr> orders = graph->GetOrderedCnodes();
+  CNodePtrList cnodes(orders.begin(), orders.end());
   bool training = graph->has_flag("training");
   auto param_cache_enable_set = FindParamCacheEnable(graph);
   if (param_cache_enable_set.empty()) {
@@ -451,6 +669,12 @@ void AddCacheEmbedding(const FuncGraphPtr &graph) {
   if (!CheckHostCacheParamSize(param_cache_enable_set)) {
     return;
   }
+  auto unique_cache_enable = FindUniqueCacheEnable(cnodes);
+  if (unique_cache_enable.empty()) {
+    MS_LOG(WARNING) << "Parameters have cache enable, but no Unique op with cache enable is found.";
+    return;
+  }
+  auto unique_node = unique_cache_enable[0];
   if (training) {
     // If training, create cache parameters corresponding to the host params with is cache_enable.
     // Replace the host params. Create hashmap then insert MapCacheIdx op after Unique with has 'cache_enable' attr.
@@ -460,75 +684,14 @@ void AddCacheEmbedding(const FuncGraphPtr &graph) {
     // flush miss values to cache params and write back old values to host params.
     // If no use pipe in training, EmbeddingLookup and CacheSwapTable must execute before SparseGatherV2, so add
     // ControlDepend between them. And add Depend for UpdateCache op and ControlDepnd op to add nodes into graph.
-    auto unique_cache_enable = FindUniqueCacheEnable(cnodes);
-    if (unique_cache_enable.empty()) {
-      MS_LOG(WARNING) << "Parameters have cache enable, but not find Unique op cache enable.";
-      return;
-    }
-    auto cache_host_params_map = AddCacheParameters(graph, param_cache_enable_set);
-    auto param_set = MapKeysToSet(cache_host_params_map);
-    ReplaceCacheParams(graph, cache_host_params_map);
-    graph->set_flag(GRAPH_FLAG_CACHE_ENABLE, true);
-    auto unique_node = unique_cache_enable[0];
-
-    CNodePtrList sparse_gatherv2_with_cache = FindSparseGatherV2WithCache(cnodes, param_set);
-    auto unique_node_output_0 = CreateTupleGetItem(graph, unique_node, 0);
-    auto map_cache_idx = CreateMapCacheIdx(graph, unique_node_output_0, cache_host_params_map);
-
-    AnfNodePtrList map_cache_idx_node_outputs;
-    CreateTupleGetItems(graph, map_cache_idx, &map_cache_idx_node_outputs);
-
-    if (!manager->Replace(sparse_gatherv2_with_cache[0]->input(2), map_cache_idx_node_outputs[0])) {
-      MS_LOG(EXCEPTION) << "MapCacheIdx output[0] replace node failed";
-    }
-
-    auto node_pair_list = CreateEmbSwapUpdate(graph, cache_host_params_map, map_cache_idx_node_outputs);
-
-    AnfNodePtr last_node = cnodes[cnodes_size - 1];
-    CNodePtr return_node;
-    if (last_node->isa<CNode>()) {
-      return_node = last_node->cast<CNodePtr>();
-    }
-    MS_EXCEPTION_IF_NULL(return_node);
-    if (!IsPrimitiveCNode(return_node, prim::kPrimReturn)) {
-      MS_LOG(EXCEPTION) << "The last cnode after sorting, not return cnode.";
-    }
-    if (return_node->inputs().size() < 2) {
-      MS_LOG(EXCEPTION) << "Number of return node inputs should be great than or equal to 2.";
-    }
-    AnfNodePtrList invalid_nodes;
-    for (auto &ele : node_pair_list) {
-      std::transform(sparse_gatherv2_with_cache.begin(), sparse_gatherv2_with_cache.end(),
-                     std::back_inserter(invalid_nodes), [&graph, &ele](const AnfNodePtr &sparse_gatherv2) {
-                       return CreateControlDepend(graph, ele.first, sparse_gatherv2);
-                     });
-      invalid_nodes.emplace_back(ele.second);
-    }
-    auto depend_node = CreateDepend(graph, invalid_nodes, return_node->input(1));
-    if (!manager->Replace(return_node->input(1), depend_node)) {
-      MS_LOG(EXCEPTION) << "Depend replace node failed";
-    }
+    // If pipe is used in training, create pipe parameters for the no-ref params (such as labels), for
+    // MapCacheIdx output[0] and for Unique output[1]. Each step then trains on the previous step's data,
+    // which hides the time of Unique and the other CPU kernels, so the very first step runs on fake data.
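+    // Reads of each pipe parameter are ordered before its refreshing Assign (see ReplaceNoRefToParams),
+    // so step N consumes step N-1's values while the new values are written for step N+1.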
MS_LOG(EXCEPTION) << "Depend replace node failed"; - } + // If use pipe in training, create parameters for no ref param such as labels and MapCacheIdx output[0] and + // Unique output[1], in each step, it will train the data from last step, so that can hide the time of Unique + // and other cpu kernels. So in the first step, it's fake data. + CacheEmbeddingForTrain(graph, is_pipe, cnodes, unique_node, param_cache_enable_set); } else { // If eval, Use EmbeddingLookup(CPU) op to replace GatherV2. // The network is the same as Host-Device mode. - auto unique_cache_enable = FindUniqueCacheEnable(cnodes); - if (unique_cache_enable.empty()) { - MS_LOG(WARNING) << "Parameters have cache enable, but not find Unique op cache enable."; - return; - } - graph->set_flag(GRAPH_FLAG_CACHE_ENABLE, true); - // replace GatherV2 to EmbeddingLookupCPU - auto indices = unique_cache_enable[0]->input(1); - auto sparse_gatherv2_with_cache = FindSparseGatherV2WithCache(cnodes, param_cache_enable_set); - for (auto &ele : sparse_gatherv2_with_cache) { - auto anf_ele = ele->cast(); - auto gatherv2 = FindGatherV2FromSparseGatherV2(graph, anf_ele); - auto param = ele->input(1)->cast(); - auto embedding_lookup = CreateEmbeddingLookup(graph, param, indices); - if (!manager->Replace(gatherv2, embedding_lookup)) { - MS_LOG(EXCEPTION) << "Depend replace node failed"; - } - } + CacheEmbeddingForEval(graph, cnodes, unique_node, param_cache_enable_set); } } } // namespace parallel diff --git a/mindspore/ccsrc/frontend/parallel/cache_embedding/cache_embedding.h b/mindspore/ccsrc/frontend/parallel/cache_embedding/cache_embedding.h index 14772a319b..a54e283725 100644 --- a/mindspore/ccsrc/frontend/parallel/cache_embedding/cache_embedding.h +++ b/mindspore/ccsrc/frontend/parallel/cache_embedding/cache_embedding.h @@ -22,7 +22,7 @@ namespace mindspore { namespace parallel { // Automatically adding control depend based on effect order and side effect analysis. 
-void AddCacheEmbedding(const FuncGraphPtr &graph);
+void AddCacheEmbedding(const FuncGraphPtr &graph, bool is_pipe = false);
 }  // namespace parallel
 }  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_FRONTEND_PARALLEL_CACHE_EMBEDDING_CACHE_EMBEDDING_H_
diff --git a/mindspore/nn/layer/embedding.py b/mindspore/nn/layer/embedding.py
index 90cb78fabb..5618223ef5 100755
--- a/mindspore/nn/layer/embedding.py
+++ b/mindspore/nn/layer/embedding.py
@@ -21,7 +21,7 @@ from mindspore.ops import functional as F
 from mindspore.common.parameter import Parameter
 from mindspore.common.initializer import initializer
 from mindspore.communication.management import get_group_size
-from mindspore.context import ParallelMode
+from mindspore.context import ParallelMode, get_context
 from mindspore.parallel._utils import _get_parallel_mode, _get_full_batch
 from mindspore.parallel._ps_context import _insert_hash_table_size, _set_cache_enable, _is_role_worker, _get_ps_context
 from mindspore._checkparam import Rel
@@ -278,7 +278,7 @@ class EmbeddingLookup(Cell):
             raise ValueError("slice_mode should support mode in nn.EmbeddingLookup, but get "
                              + str(slice_mode))
         if self.cache_enable and not enable_ps:
-            if is_auto_parallel:
+            if parallel_mode != ParallelMode.STAND_ALONE:
                 raise ValueError("parallel mode haven't supported cache enable yet.")
             self._set_cache_enable()
         self.embedding_table.unique = self.forward_unique
@@ -288,15 +288,14 @@ class EmbeddingLookup(Cell):
             self.max_norm = Tensor(self.max_norm, dtype=mstype.float32)
 
     def _set_cache_enable(self):
-        """EmbeddingLookup cache check for not ps env."""
+        """EmbeddingLookup cache check for the non-PS env, which only supports 'Ascend'."""
         if self.target != 'DEVICE':
-            logger.warning("The configuration of 'vocab_cache_size' is valid only in 'DEVICE' target, "
-                           "so it will be ignored.")
-            return
+            raise ValueError("The configuration of 'vocab_cache_size' is valid only in 'DEVICE' target.")
         if not self.sparse:
-            logger.warning("The configuration of 'vocab_cache_size' is valid only 'sparse' is true, "
-                           "so it will be ignored.")
-            return
+            raise ValueError("The configuration of 'vocab_cache_size' is valid only when 'sparse' is true.")
+        if get_context("device_target") != 'Ascend':
+            raise ValueError("The configuration of 'vocab_cache_size' is valid only on the 'Ascend' device target.")
+        logger.info("EmbeddingLookup cache enable takes effect.")
         self.forward_unique = True
         self.unique = P.Unique().add_prim_attr('primitive_target', 'CPU')
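
A minimal usage sketch of the configuration the new checks in _set_cache_enable()
accept: standalone graph mode on Ascend, device-side cache with sparse lookup. The
vocab/cache sizes and input values below are illustrative assumptions, not part of
this patch:

    import numpy as np
    import mindspore.nn as nn
    import mindspore.common.dtype as mstype
    from mindspore import Tensor, context

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

    # vocab_cache_size > 0 enables the device cache; target must be 'DEVICE' and
    # sparse must be True, otherwise _set_cache_enable() now raises ValueError.
    embedding = nn.EmbeddingLookup(vocab_size=4000000, embedding_size=64,
                                   target='DEVICE', sparse=True,
                                   vocab_cache_size=200000)
    indices = Tensor(np.array([[1, 5], [3, 9]]), mstype.int32)
    out = embedding(indices)  # shape: (2, 2, 64)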