From 9febf7fdf511ab930cfb9a5f2d2e8c6c13a3d8e8 Mon Sep 17 00:00:00 2001
From: hongxing
Date: Sun, 28 Jun 2020 16:21:11 +0200
Subject: [PATCH] support GatherV2P

---
 .../rec_core/rec_generate_strategy.cc         | 31 +++++++++++++++++--
 .../rec_core/rec_generate_strategy.h          |  3 +-
 .../auto_parallel/rec_core/rec_parse_graph.cc |  4 +--
 .../auto_parallel/rec_core/rec_parse_graph.h  |  4 +--
 .../auto_parallel/rec_core/rec_partition.cc   |  9 +++---
 .../auto_parallel/rec_core/rec_partition.h    |  8 ++---
 6 files changed, 43 insertions(+), 16 deletions(-)

diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc
index b8a57ae997..9de71231c0 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc
+++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc
@@ -164,9 +164,34 @@ std::vector<std::vector<int32_t>> PrepareOneHot(const std::shared_ptr<Graph> &gr
   return strategies;
 }
 
-std::vector<std::vector<int32_t>> PrepareGatherV2(const std::shared_ptr<std::vector<int32_t>> &s) {
+std::vector<std::vector<int32_t>> PrepareGatherV2(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                                                  const size_t iter_ops, std::vector<int32_t> s) {
   std::vector<std::vector<int32_t>> strategies;
-  strategies.push_back(*s);
+
+  int32_t axis = 0;
+  auto axis_input = GetValue<int>(ops[iter_ops]->input_value().at(2));
+  if (axis_input < 0) {
+    axis_input += SizeToInt(ops[iter_ops]->inputs_tensor_info()[0].shape().size());
+  }
+  axis = axis_input;
+  if (axis >= SizeToInt(s.size())) {
+    MS_LOG(EXCEPTION) << "Failure: GatherV2's axis is out of range.";
+  }
+  s[axis] = 1;
+  strategies.push_back(s);
+
+  auto pos = ops[iter_ops]->name().find("Info");
+  auto name = ops[iter_ops]->name().substr(0, pos);
+  if (name == "GatherV2") {
+    return strategies;
+  }
+
+  std::vector<int32_t> s_indices;
+  for (size_t i = 0; i < ops[iter_ops]->inputs_tensor_info()[1].shape().size(); i++) {
+    s_indices.push_back(1);
+  }
+  strategies.push_back(s_indices);
+
   return strategies;
 }
 
@@ -607,7 +632,7 @@ std::vector<std::vector<int32_t>> GenerateStrategiesFromStrategy(const std::vect
     return PrepareBiasAdd(s_ptr);
   }
   if (ops[iter_ops]->type() == GATHERV2) {
-    return PrepareGatherV2(s_ptr);
+    return PrepareGatherV2(ops, iter_ops, basic_stra);
   }
   if (ops[iter_ops]->type() == L2_NORMALIZE) {
     return PrepareL2Normalize(ops, iter_ops, basic_stra);
diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.h b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.h
index 1e8080f2b7..e82efe6798 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.h
+++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.h
@@ -38,7 +38,8 @@ std::vector<std::vector<int32_t>> PrepareBiasAdd(const std::shared_ptr<std::vector<int32_t>> &s);
 std::vector<std::vector<int32_t>> PrepareOneHot(const std::shared_ptr<Graph> &graph,
                                                 const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                 const size_t iter_graph, const size_t iter_ops);
-std::vector<std::vector<int32_t>> PrepareGatherV2(const std::shared_ptr<std::vector<int32_t>> &s);
+std::vector<std::vector<int32_t>> PrepareGatherV2(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
+                                                  const size_t iter_ops, std::vector<int32_t> s);
 std::vector<std::vector<int32_t>> PrepareL2Normalize(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                                      const size_t iter_ops, std::vector<int32_t> s);
 std::vector<std::vector<int32_t>> MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph,
diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc
index 190a716063..c0412e9108 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc
+++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc
@@ -40,7 +40,7 @@ const TensorParam MakeTensor(int n, int c, int h, int w) {
   return tensor;
 }
 
-Graph::NodeType MakeNewOperator(std::vector<std::shared_ptr<OperatorInfo>> ops, size_t iter_ops) {
+Graph::NodeType MakeNewOperator(const std::vector<std::shared_ptr<OperatorInfo>> &ops, size_t iter_ops) {
   Graph::NodeType NewOp;
   NewOp.name = ops[iter_ops]->name();
   NewOp.info = InfoType::kApplication;
@@ -140,7 +140,7 @@ std::shared_ptr<Graph> ParseGraph(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
   return graph;
 }
 
-void MakeEdge(const std::vector<std::vector<std::string>> &input_tensor_names, std::shared_ptr<Graph> graph) {
+void MakeEdge(const std::vector<std::vector<std::string>> &input_tensor_names, const std::shared_ptr<Graph> &graph) {
   for (size_t iter_i = 0; iter_i < input_tensor_names.size(); iter_i++) {
     for (size_t iter_j = 1; iter_j < input_tensor_names[iter_i].size(); iter_j++) {
       size_t head_node_index = GetIndexInInputTensorNames(input_tensor_names, input_tensor_names[iter_i][iter_j]);
diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h
index a696e88332..53abefd1c8 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h
+++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h
@@ -110,7 +110,7 @@ const std::map<std::string, OperatorType> DictOpType{
 
 const TensorParam MakeTensor(int n, int c, int h, int w);
 
-Graph::NodeType MakeNewOperator(std::vector<std::shared_ptr<OperatorInfo>> ops, size_t iter_ops);
+Graph::NodeType MakeNewOperator(const std::vector<std::shared_ptr<OperatorInfo>> &ops, size_t iter_ops);
 
 OperatorRec CompleteOperatorInputs(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
                                    Graph::NodeType NewTensor);
@@ -121,7 +121,7 @@ TensorParam Complete2DInputs(const std::vector<std::shared_ptr<OperatorInfo>> &o
 std::shared_ptr<Graph> ParseGraph(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
                                   const std::vector<std::vector<std::string>> &input_tensor_names);
 
-void MakeEdge(const std::vector<std::vector<std::string>> &input_tensor_names, std::shared_ptr<Graph> graph);
+void MakeEdge(const std::vector<std::vector<std::string>> &input_tensor_names, const std::shared_ptr<Graph> &graph);
 
 size_t GetIndexInInputTensorNames(const std::vector<std::vector<std::string>> &input_tensor_names,
                                   const std::string &input_name);
diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.cc b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.cc
index 0f6e736d52..d5200f54d8 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.cc
+++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.cc
@@ -93,7 +93,7 @@ double GetWeights(const Graph::NodeType &node) {
 }
 
 // Sort all the nodes by their weights
-std::vector<size_t> SortByWeight(const std::shared_ptr<Graph> graph) {
+std::vector<size_t> SortByWeight(const std::shared_ptr<Graph> &graph) {
   MS_EXCEPTION_IF_NULL(graph);
 
   std::vector<std::pair<double, size_t>> weight_to_node_index;
@@ -124,7 +124,7 @@ std::vector<size_t> SortByWeight(const std::shared_ptr<Graph> graph) {
 // Get optimal strategy to partition the target node
 StrategyRec PartitionNode(const Graph::NodeType &node,
                           const std::vector<std::pair<std::string, StrategyRec>> &node_name_to_strategy,
-                          std::shared_ptr<Graph> graph) {
+                          const std::shared_ptr<Graph> &graph) {
   bool enable_conv_chw_partition = false;
   MS_EXCEPTION_IF_NULL(graph);
 
@@ -191,7 +191,8 @@ StrategyRec PartitionNode(const Graph::NodeType &node,
 }
 
 // Parttion graph into all devices.
-Status PartitionForAllDevices(const size_t num_device, const double device_memory, std::shared_ptr<Graph> graph) {
+Status PartitionForAllDevices(const size_t num_device, const double device_memory,
+                              const std::shared_ptr<Graph> &graph) {
   if (num_device < 1) {
     MS_LOG(EXCEPTION) << "ERROR: Number of devices can't be " << num_device << ".";
   }
@@ -261,7 +262,7 @@ Graph::NodeType ApplyStrToTensor(Graph::NodeType Node) {
   return Node;
 }
-Status DevicesMemoryControl(const size_t num_device, const double device_memory, std::shared_ptr<Graph> graph) {
+Status DevicesMemoryControl(const size_t num_device, const double device_memory, const std::shared_ptr<Graph> &graph) {
   MS_EXCEPTION_IF_NULL(graph);
   if (num_device == 0) {
     MS_LOG(EXCEPTION) << "Failure: device number is 0.";
   }
diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.h b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.h
index b2fbeddebd..c98f3317f8 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.h
+++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.h
@@ -32,19 +32,19 @@
 namespace mindspore {
 namespace parallel {
 
-std::vector<size_t> SortByWeight(const std::shared_ptr<Graph> graph);
+std::vector<size_t> SortByWeight(const std::shared_ptr<Graph> &graph);
 
 double GetWeights(const Graph::NodeType &node);
 
 StrategyRec PartitionNode(const Graph::NodeType &node,
                           const std::vector<std::pair<std::string, StrategyRec>> &node_name_to_strategy,
-                          std::shared_ptr<Graph> graph);
+                          const std::shared_ptr<Graph> &graph);
 
-Status PartitionForAllDevices(const size_t num_device, const double device_memory, std::shared_ptr<Graph> graph);
+Status PartitionForAllDevices(const size_t num_device, const double device_memory, const std::shared_ptr<Graph> &graph);
 
 Graph::NodeType ApplyStrToTensor(Graph::NodeType Node);
 
-Status DevicesMemoryControl(const size_t num_device, const double device_memory, std::shared_ptr<Graph> graph);
+Status DevicesMemoryControl(const size_t num_device, const double device_memory, const std::shared_ptr<Graph> &graph);
 
 size_t GetDataTypeSize(const TensorType &type);
 } // namespace parallel
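
For reference, below is a minimal standalone C++ sketch (not part of the applied diff) of the strategy derivation that the reworked PrepareGatherV2 performs: the gather axis (possibly negative) is normalized against the parameter rank, the split factor on that axis is forced to 1 so the gathered dimension is never partitioned, and for the partitioned variant (the GatherV2P case this patch enables) an all-ones strategy is appended for the indices input. The helper name DeriveGatherStrategies and its parameters are illustrative only; they do not appear in the patch or in MindSpore.

#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <vector>

// Derive per-input parallel strategies for a GatherV2-like operator.
//   param_strategy : tentative split factors, one per dimension of the parameter input
//   axis           : gather axis, possibly negative (counted from the last dimension)
//   indices_rank   : rank of the indices input
//   indices_needed : true for the partitioned variant, which also expects a strategy
//                    for the indices input (kept replicated, i.e. all ones)
std::vector<std::vector<int32_t>> DeriveGatherStrategies(std::vector<int32_t> param_strategy, int32_t axis,
                                                         size_t indices_rank, bool indices_needed) {
  if (axis < 0) {
    // The patch adds the parameter tensor's rank; the strategy length stands in for it here.
    axis += static_cast<int32_t>(param_strategy.size());
  }
  if (axis < 0 || axis >= static_cast<int32_t>(param_strategy.size())) {
    throw std::out_of_range("GatherV2's axis is out of range");
  }
  param_strategy[axis] = 1;  // never split the gathered dimension

  std::vector<std::vector<int32_t>> strategies;
  strategies.push_back(param_strategy);
  if (indices_needed) {
    strategies.emplace_back(indices_rank, 1);  // replicate the indices input
  }
  return strategies;
}

int main() {
  // Rank-2 parameter tentatively split 8-way on dim 0, gather axis 0, rank-1 indices,
  // partitioned variant: prints "1 1" for the parameter and "1" for the indices.
  for (const auto &strategy : DeriveGatherStrategies({8, 1}, 0, 1, true)) {
    for (int32_t dim : strategy) {
      std::cout << dim << ' ';
    }
    std::cout << '\n';
  }
  return 0;
}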