From caac6bce5c3445a5f444cd9a18521f4024972171 Mon Sep 17 00:00:00 2001
From: ch-l
Date: Wed, 29 Apr 2020 10:16:32 +0200
Subject: [PATCH] adjustments w.r.t. distributed execution

Compute the BatchNorm cost weight from the operator's first input tensor
instead of returning a constant 0.0, fix the Convolution output-tensor
cost to pair each shape dimension with its own strategy stride, map 2-D
outputs directly with MakeTensor (dropping Fill2DTensor), route
SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS to PrepareSparse, and remove the
InferUndecideStrategy pass together with the MatMul/BiasAdd strategy
masking.
---
 .../auto_parallel/rec_core/rec_cost.cc        | 24 +++++-
 .../auto_parallel/rec_core/rec_cost.h         |  2 +-
 .../rec_core/rec_generate_strategy.cc         | 16 ++--
 .../auto_parallel/rec_core/rec_parse_graph.cc | 26 +-----
 .../auto_parallel/rec_core/rec_parse_graph.h  |  3 -
 .../auto_parallel/rec_core/rec_partition.cc   | 81 +------------------
 .../auto_parallel/rec_core/rec_partition.h    |  6 --
 7 files changed, 32 insertions(+), 126 deletions(-)

diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_cost.cc b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_cost.cc
index 3fea107a73..e5ba59425c 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_cost.cc
+++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_cost.cc
@@ -296,10 +296,10 @@ double CostConvolution::GetMinCostIn(const Graph::NodeType &node) {
                       static_cast<int>(op.arguments[1].tensor_shape.shape_n * op.arguments[1].tensor_str.str_n) *
                       static_cast<int>(op.arguments[1].tensor_shape.shape_w * op.arguments[1].tensor_str.str_w) *
                       static_cast<int>(op.arguments[1].tensor_shape.shape_c * op.arguments[1].tensor_str.str_c);
-  int tensor_out = static_cast<int>(node.tensor_parm.tensor_shape.shape_h * node.tensor_parm.tensor_shape.shape_w) *
-                   static_cast<int>(node.tensor_parm.tensor_shape.shape_n * node.tensor_parm.tensor_shape.shape_c) *
-                   static_cast<int>(node.tensor_parm.tensor_str.str_h * node.tensor_parm.tensor_str.str_w) *
-                   static_cast<int>(node.tensor_parm.tensor_str.str_n * node.tensor_parm.tensor_str.str_c);
+  int tensor_out = static_cast<int>(node.tensor_parm.tensor_shape.shape_h * node.tensor_parm.tensor_str.str_h) *
+                   static_cast<int>(node.tensor_parm.tensor_shape.shape_n * node.tensor_parm.tensor_str.str_n) *
+                   static_cast<int>(node.tensor_parm.tensor_shape.shape_w * node.tensor_parm.tensor_str.str_w) *
+                   static_cast<int>(node.tensor_parm.tensor_shape.shape_c * node.tensor_parm.tensor_str.str_c);
 
   std::vector<double> cost_in;
   cost_in.push_back(StrDimB(tensor_filter));
@@ -628,6 +628,22 @@ StrategyRec CostCommon::ChoseStr(const std::vector<double> &cost_op, StrategyRec
   return str;
 }
 
+// Get weight for BN
+double CostBatchNorm::GetMinCostIn(const OperatorRec &op) {
+  int tensor = static_cast<int>(op.arguments[0].tensor_shape.shape_h * op.arguments[0].tensor_str.str_h) *
+               static_cast<int>(op.arguments[0].tensor_shape.shape_n * op.arguments[0].tensor_str.str_n) *
+               static_cast<int>(op.arguments[0].tensor_shape.shape_w * op.arguments[0].tensor_str.str_w) *
+               static_cast<int>(op.arguments[0].tensor_shape.shape_c * op.arguments[0].tensor_str.str_c);
+
+  std::vector<double> cost_in;
+  cost_in.push_back(StrDimB(tensor) * 1.2);
+  cost_in.push_back(DOUBLE_MAX);
+  cost_in.push_back(StrDimH(tensor) * 1.2);
+  cost_in.push_back(StrDimW(tensor) * 1.2);
+
+  return *min_element(cost_in.begin(), cost_in.end());
+}
+
 // Get optimal strategy for BN
 StrategyRec CostBatchNorm::GetOptimalStr(const Graph::NodeType &node,
                                          const std::vector<std::pair<std::string, StrategyRec>> &node_name_to_strategy,
diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_cost.h b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_cost.h
index 85e5e5ea94..315c081d67 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_cost.h
+++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_cost.h
@@ -213,7 +213,7 @@ class CostBatchNorm {
                             const std::vector<std::pair<std::string, StrategyRec>> &node_name_to_strategy,
                             const Graph &graph);
 
-  double GetMinCostIn() const { return 0.0; }
+  double GetMinCostIn(const OperatorRec &op);
 
  private:
  double StrDimB(int32_t Tensor) {
diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc
index e942c8005f..42b3bfc72e 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc
+++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc
@@ -132,8 +132,9 @@ std::vector<int32_t> MakeOriginalStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
   if (iter_ops >= ops.size()) {
     MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range.";
   }
-  if (iter_op_inputs >= ops[iter_ops]->strategy()->GetInputDim().size())
+  if (iter_op_inputs >= ops[iter_ops]->strategy()->GetInputDim().size()) {
     MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range.";
+  }
   size_t input_size = ops[iter_ops]->strategy()->GetInputDim()[iter_op_inputs].size();
   for (size_t dim = 0; dim < input_size; dim++) {
     s.push_back(1);
@@ -161,8 +162,9 @@ std::vector<int32_t> MakeDataParallelStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
   }
 
   StrategyPtr origin_strategy = ops[iter_ops]->strategy();
-  if (iter_op_inputs >= origin_strategy->GetInputDim().size())
+  if (iter_op_inputs >= origin_strategy->GetInputDim().size()) {
     MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range.";
+  }
   size_t input_size = origin_strategy->GetInputDim()[iter_op_inputs].size();
   for (size_t dim = 0; dim < input_size; dim++) {
     if (dim == 0 && input_size == 4) {
@@ -198,9 +200,9 @@ std::vector<int32_t> PrepareStrategy(const std::shared_ptr<Graph> &graph,
     return MakeOriginalStrategy(ops, iter_ops, iter_op_inputs);
   } else if (type == RELU) {
     return MakeRecSearchStrategy(graph, iter_ops, iter_op_inputs);
-  } else if (type == BATCH_NORM || (type == FUSE_BATCH_NORM)) {
+  } else if ((type == BATCH_NORM) || (type == FUSE_BATCH_NORM)) {
     return PrepareBN(graph, iter_ops, iter_op_inputs);
-  } else if (type == SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) {
+  } else if (type == SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) {
     return PrepareSparse(iter_op_inputs);
   } else {
     return MakeDataParallelStrategy(ops, iter_ops, iter_op_inputs);
@@ -224,12 +226,6 @@ void MaskSpecialOps(std::shared_ptr<Graph> graph) {
       node.apply.arguments[1].tensor_str.str_c = 1;
       node.apply.arguments[1].tensor_str.str_h = 1;
       node.apply.arguments[1].tensor_str.str_w = 1;
-    } else if (node.apply.op_type == kRecBiasAdd || node.apply.op_type == kRecMatMul) {
-      // For MatMul and BiasAdd
-      node.apply.arguments[0].tensor_str.str_h = 1;
-      node.apply.arguments[0].tensor_str.str_w = 1;
-      node.apply.arguments[1].tensor_str.str_h = 1;
-      node.apply.arguments[1].tensor_str.str_w = 1;
     }
   }
 }
diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc
index b9b1b7b914..ada22fef9a 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc
+++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc
@@ -58,7 +58,8 @@ Graph::NodeType MakeNewOperator(std::vector<std::shared_ptr<OperatorInfo>> ops,
       ops[iter_ops]->outputs_tensor_info()[0].shape()[0], ops[iter_ops]->outputs_tensor_info()[0].shape()[1],
       ops[iter_ops]->outputs_tensor_info()[0].shape()[2], ops[iter_ops]->outputs_tensor_info()[0].shape()[3]);
   } else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 2) {
-    NewOp.tensor_parm = Fill2DTensor(ops, iter_ops, NewOp);
+    NewOp.tensor_parm = MakeTensor(1, 1, ops[iter_ops]->outputs_tensor_info()[0].shape()[0],
+                                   ops[iter_ops]->outputs_tensor_info()[0].shape()[1]);
   } else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 1) {
     NewOp.tensor_parm = MakeTensor(1, 1, 1, ops[iter_ops]->outputs_tensor_info()[0].shape()[0]);
   } else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 0) {
@@ -71,29 +72,6 @@ Graph::NodeType MakeNewOperator(std::vector<std::shared_ptr<OperatorInfo>> ops,
   return NewOp;
 }
 
-TensorParam Fill2DTensor(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
-                         Graph::NodeType NewTensor) {
-  if (NewTensor.apply.op_type == OperatorType::kRecMatMul) {
-    auto attrs = ops[iter_ops]->attrs();
-    bool transpose_a = attrs[TRANSPOSE_A]->cast<BoolImmPtr>()->value();
-    bool transpose_b = attrs[TRANSPOSE_B]->cast<BoolImmPtr>()->value();
-    if (transpose_a) {
-      NewTensor.tensor_parm = MakeTensor(1, 1, ops[iter_ops]->outputs_tensor_info()[0].shape()[1],
-                                         ops[iter_ops]->outputs_tensor_info()[0].shape()[0]);
-    } else if (transpose_b) {
-      NewTensor.tensor_parm = MakeTensor(1, 1, ops[iter_ops]->outputs_tensor_info()[0].shape()[1],
-                                         ops[iter_ops]->outputs_tensor_info()[0].shape()[0]);
-    } else {
-      NewTensor.tensor_parm = MakeTensor(1, 1, ops[iter_ops]->outputs_tensor_info()[0].shape()[0],
-                                         ops[iter_ops]->outputs_tensor_info()[0].shape()[1]);
-    }
-  } else {
-    NewTensor.tensor_parm = MakeTensor(1, 1, ops[iter_ops]->outputs_tensor_info()[0].shape()[0],
-                                       ops[iter_ops]->outputs_tensor_info()[0].shape()[1]);
-  }
-  return NewTensor.tensor_parm;
-}
-
 OperatorRec CompleteOperatorInputs(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
                                    Graph::NodeType NewTensor) {
   for (size_t iter_input_tensors = 0; iter_input_tensors < ops[iter_ops]->inputs_tensor_info().size();
diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h
index 17a8174dde..2b1d0c55ed 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h
+++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h
@@ -53,9 +53,6 @@ const TensorParam MakeTensor(int n, int c, int h, int w);
 
 Graph::NodeType MakeNewOperator(std::vector<std::shared_ptr<OperatorInfo>> ops, size_t iter_ops);
 
-TensorParam Fill2DTensor(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
-                         Graph::NodeType NewTensor);
-
 OperatorRec CompleteOperatorInputs(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
                                    Graph::NodeType NewTensor);
 
diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.cc b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.cc
index 5fcaefcb47..3527c18079 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.cc
+++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.cc
@@ -73,7 +73,7 @@ double GetWeights(const Graph::NodeType &node) {
     // For BatchNorm
     auto cost_ptr = std::make_shared<CostBatchNorm>();
 
-    return cost_ptr->GetMinCostIn();
+    return cost_ptr->GetMinCostIn(op);
   } else if (op.op_type == OperatorType::kRecOneHot || op.op_type == OperatorType::kRecLog ||
              op.op_type == OperatorType::kRecExp || op.op_type == OperatorType::kRecAdd ||
              op.op_type == OperatorType::kRecSub || op.op_type == OperatorType::kRecMul ||
@@ -108,8 +108,8 @@ std::vector<size_t> SortByWeight(const std::shared_ptr<Graph> graph) {
     }
   }
 
-  // Do sorting.
-  sort(weight_to_node_index.begin(), weight_to_node_index.end());
+  // Ordering ops aka nodes of the graph
+  std::sort(weight_to_node_index.begin(), weight_to_node_index.end());
 
   // Store the result in node_index_by_weights.
   uint64_t size = weight_to_node_index.size();
@@ -231,7 +231,6 @@ Status PartitionForAllDevices(const size_t num_device, const double device_memor
     }
   }
 
-  InferUndecideStrategy(graph);
   if (DevicesMemoryControl(device_memory, graph) != SUCCESS) {
     return FAILED;
   } else {
@@ -257,80 +256,6 @@ Graph::NodeType ApplyStrToTensor(Graph::NodeType Node) {
   return Node;
 }
 
-// Check Strategy for the same tensor between op.
-void InferUndecideStrategy(std::shared_ptr<Graph> graph) {
-  MS_EXCEPTION_IF_NULL(graph);
-
-  uint64_t iter_nodes = graph->nodes.size();
-
-  // For all the nodes in the graph
-  for (uint64_t i_node = 0; i_node < iter_nodes; i_node++) {
-    // If this target node is an operator, find it's adjecent op's strategy;
-    if (graph->nodes[i_node].info == 0) {
-      // Try to apply last op's strategy.
-      ApplyLastStrategy(i_node, graph);
-      // Try to apply next op's strategy.
-      ApplyNextStrategy(i_node, graph);
-    }
-  }
-}
-
-void ApplyLastStrategy(const uint64_t node_index, std::shared_ptr<Graph> graph) {
-  Graph::NodeType &target_node = graph->nodes[node_index];
-
-  // Number of node-in
-  size_t num_node_in = target_node.node_in.size();
-
-  // Find forward op and copy strategy if meets the limits.
-  for (size_t index = 0; index < num_node_in; index++) {
-    if (graph->nodes[target_node.node_in[index]].tensor_parm.tensor_str.str_n <=
-          target_node.apply.arguments[0].tensor_str.str_n &&
-        graph->nodes[target_node.node_in[index]].tensor_parm.tensor_str.str_c <=
-          target_node.apply.arguments[0].tensor_str.str_c &&
-        graph->nodes[target_node.node_in[index]].tensor_parm.tensor_str.str_h <=
-          target_node.apply.arguments[0].tensor_str.str_h &&
-        graph->nodes[target_node.node_in[index]].tensor_parm.tensor_str.str_w <=
-          target_node.apply.arguments[0].tensor_str.str_w) {
-      target_node.apply.arguments[0].tensor_str.str_n =
-        graph->nodes[target_node.node_in[index]].tensor_parm.tensor_str.str_n;
-      target_node.apply.arguments[0].tensor_str.str_c =
-        graph->nodes[target_node.node_in[index]].tensor_parm.tensor_str.str_c;
-      target_node.apply.arguments[0].tensor_str.str_h =
-        graph->nodes[target_node.node_in[index]].tensor_parm.tensor_str.str_h;
-      target_node.apply.arguments[0].tensor_str.str_w =
-        graph->nodes[target_node.node_in[index]].tensor_parm.tensor_str.str_w;
-    }
-  }
-}
-
-void ApplyNextStrategy(const uint64_t node_index, std::shared_ptr<Graph> graph) {
-  Graph::NodeType &target_node = graph->nodes[node_index];
-
-  // Number of node-out
-  size_t num_node_out = target_node.node_out.size();
-
-  // Find backward op and copy strategy if meets the limits.
-  for (size_t index = 0; index < num_node_out; index++) {
-    if (graph->nodes[target_node.node_out[index]].apply.arguments[0].tensor_str.str_n <=
-          target_node.tensor_parm.tensor_str.str_n &&
-        graph->nodes[target_node.node_out[index]].apply.arguments[0].tensor_str.str_c <=
-          target_node.tensor_parm.tensor_str.str_c &&
-        graph->nodes[target_node.node_out[index]].apply.arguments[0].tensor_str.str_h <=
-          target_node.tensor_parm.tensor_str.str_h &&
-        graph->nodes[target_node.node_out[index]].apply.arguments[0].tensor_str.str_w <=
-          target_node.tensor_parm.tensor_str.str_w) {
-      target_node.tensor_parm.tensor_str.str_n =
-        graph->nodes[target_node.node_out[index]].apply.arguments[0].tensor_str.str_n;
-      target_node.tensor_parm.tensor_str.str_c =
-        graph->nodes[target_node.node_out[index]].apply.arguments[0].tensor_str.str_c;
-      target_node.tensor_parm.tensor_str.str_h =
-        graph->nodes[target_node.node_out[index]].apply.arguments[0].tensor_str.str_h;
-      target_node.tensor_parm.tensor_str.str_w =
-        graph->nodes[target_node.node_out[index]].apply.arguments[0].tensor_str.str_w;
-    }
-  }
-}
-
 Status DevicesMemoryControl(const double device_memory, std::shared_ptr<Graph> graph) {
   MS_EXCEPTION_IF_NULL(graph);
 
diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.h b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.h
index e22b11542a..fc504b3cb2 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.h
+++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.h
@@ -44,12 +44,6 @@ Status PartitionForAllDevices(const size_t num_device, const double device_memor
 
 Graph::NodeType ApplyStrToTensor(Graph::NodeType Node);
 
-void InferUndecideStrategy(std::shared_ptr<Graph> graph);
-
-void ApplyLastStrategy(const uint64_t node_index, std::shared_ptr<Graph> graph);
-
-void ApplyNextStrategy(const uint64_t node_index, std::shared_ptr<Graph> graph);
-
 Status DevicesMemoryControl(const double device_memory, std::shared_ptr<Graph> graph);
 
 size_t GetDataTypeSize(const TensorType &type);