diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.cc b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.cc
index 24ad8ac203..81e0eaa2dd 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.cc
+++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.cc
@@ -29,52 +29,55 @@
 namespace mindspore {
 namespace parallel {
-#define DEVICE_MEMORY 1024.0 * 1024.0 * 1024.0  // 1GB
 
 // Get the target node's weight for sorting.
 double GetWeights(const Graph::NodeType &node) {
   const OperatorRec &op = node.apply;
 
-  if (op.op_type == 0) {
+  if (op.op_type == OperatorType::kRecMatMul) {
     // For MatMul
     auto cost_ptr = std::make_shared<CostMatMul>();
 
     return cost_ptr->GetMinCostIn(op);
-  } else if (op.op_type == 1) {
+  } else if (op.op_type == OperatorType::kRecConvolution) {
     // For Convolution
     auto cost_ptr = std::make_shared<CostConvolution>();
 
     return cost_ptr->GetMinCostIn(node);
-  } else if (op.op_type == 2) {
+  } else if (op.op_type == OperatorType::kRecPooling) {
     // For Pooling
     auto cost_ptr = std::make_shared<CostPooling>();
 
     return cost_ptr->GetMinCostIn();
-  } else if (op.op_type == 3) {
+  } else if (op.op_type == OperatorType::kRecAdd) {
     // For Add
     auto cost_ptr = std::make_shared<CostAdd>();
 
     return cost_ptr->GetMinCostIn();
-  } else if (op.op_type == 4 || op.op_type == 7 || op.op_type == 9) {
+  } else if (op.op_type == OperatorType::kRecSoftmax || op.op_type == OperatorType::kRecReLU ||
+             op.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits) {
     // For Softmax, ReLU & SparseSoftmaxCrossEntropyWithLogits
     auto cost_ptr = std::make_shared<CostCommon>();
 
     return cost_ptr->GetMinCostIn();
-  } else if (op.op_type == 5) {
+  } else if (op.op_type == OperatorType::kRecReshape) {
     // For Reshape
     auto cost_ptr = std::make_shared<CostReshape>();
 
     return cost_ptr->GetMinCostIn();
-  } else if (op.op_type == 6) {
+  } else if (op.op_type == OperatorType::kRecBiasAdd) {
     // For BiasAdd
     auto cost_ptr = std::make_shared<CostBiasAdd>();
 
     return cost_ptr->GetMinCostIn();
-  } else if (op.op_type == 8) {
+  } else if (op.op_type == OperatorType::kRecBatchNorm) {
     // For BatchNorm
     auto cost_ptr = std::make_shared<CostBatchNorm>();
 
     return cost_ptr->GetMinCostIn();
+  } else if (op.op_type == OperatorType::kRecUnkownType) {
+    // For unknown type
+    return 0.0;
   } else {
     MS_LOG(EXCEPTION) << "Failure: GetOperatorWeight failed.";
   }
@@ -155,13 +158,17 @@ StrategyRec PartitionNode(const Graph::NodeType &node,
     auto cost_ptr = std::make_shared<CostBatchNorm>();
 
     return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph);
+  } else if (node.apply.op_type == OperatorType::kRecUnkownType) {
+    // For unknown type
+    StrategyRec default_strategy;
+    return default_strategy;
   } else {
     MS_LOG(EXCEPTION) << "Failure: Partition Operator failed.";
   }
 }
 
 // Partition graph into all devices.
-Status PartitionForAllDevices(const size_t num_device, std::shared_ptr<Graph> graph) {
+Status PartitionForAllDevices(const size_t num_device, const double device_memory, std::shared_ptr<Graph> graph) {
   if (num_device < 1) {
     MS_LOG(EXCEPTION) << "ERROR: Number of devices can't be " << num_device << ".";
   }
@@ -207,7 +214,7 @@ Status PartitionForAllDevices(const size_t num_device, std::shared_ptr<Graph> graph) {
   }
 
   InferUndecideStrategy(graph);
-  if (DevicesMemoryControl(graph) != SUCCESS) {
+  if (DevicesMemoryControl(device_memory, graph) != SUCCESS) {
     return FAILED;
   } else {
     return SUCCESS;
@@ -306,15 +313,15 @@ void ApplyNextStrategy(const uint64_t node_index, std::shared_ptr<Graph> graph)
   }
 }
 
-Status DevicesMemoryControl(std::shared_ptr<Graph> graph) {
+Status DevicesMemoryControl(const double device_memory, std::shared_ptr<Graph> graph) {
   MS_EXCEPTION_IF_NULL(graph);
 
   uint64_t iter_nodes = graph->nodes.size();
+  double used_memory = 0.0;
 
   for (uint64_t i_node = 0; i_node < iter_nodes; i_node++) {
     if (graph->nodes[i_node].info == 0) {
       Graph::NodeType &Node = graph->nodes[i_node];
-      double used_memory = 0.0;
 
       for (int index = 0; index < 2; index++) {
         used_memory += Node.apply.arguments[index].tensor_str.str_n * Node.apply.arguments[index].tensor_shape.shape_n *
@@ -329,12 +336,12 @@ Status DevicesMemoryControl(std::shared_ptr<Graph> graph) {
                       Node.tensor_parm.tensor_str.str_h * Node.tensor_parm.tensor_shape.shape_h *
                       Node.tensor_parm.tensor_str.str_w * Node.tensor_parm.tensor_shape.shape_w *
                       GetDataTypeSize(Node.tensor_parm.tensor_type);
-      if (DEVICE_MEMORY < used_memory) {
-        MS_LOG(EXCEPTION) << "Failure: Out of memory!";
-        return FAILED;
-      }
     }
   }
 
+  if (device_memory < used_memory) {
+    MS_LOG(EXCEPTION) << "Failure: Out of memory!";
+    return FAILED;
+  }
   return SUCCESS;
 }
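Note on the rec_partition.cc change above: `used_memory` is now accumulated across every node in the graph and compared once against the caller-supplied `device_memory` budget, instead of comparing each node's footprint against the hard-coded 1 GB `DEVICE_MEMORY` macro. A minimal sketch of that accounting, using hypothetical simplified stand-ins for the rec_core strided-tensor fields (the real structs live in the rec_core headers):

#include <cstdio>
#include <vector>

// Hypothetical, simplified stand-in for a rec_core strided tensor: each
// str_* factor in (0, 1] is the fraction of that dimension kept per device.
struct TensorSketch {
  double str_n = 1.0, str_c = 1.0, str_h = 1.0, str_w = 1.0;
  double shape_n = 1.0, shape_c = 1.0, shape_h = 1.0, shape_w = 1.0;
  double type_size = 4.0;  // bytes per element, e.g. float32
};

// Bytes one device holds for a strided tensor, mirroring the products
// DevicesMemoryControl sums over a node's two inputs and its output.
double TensorBytes(const TensorSketch &t) {
  return t.str_n * t.shape_n * t.str_c * t.shape_c * t.str_h * t.shape_h *
         t.str_w * t.shape_w * t.type_size;
}

int main() {
  // Two MatMul inputs plus the output tensor of a single node.
  std::vector<TensorSketch> tensors(3);
  tensors[0].shape_n = 32;   tensors[0].shape_c = 1024; tensors[0].str_c = 0.25;
  tensors[1].shape_c = 1024; tensors[1].shape_h = 1024; tensors[1].str_c = 0.25;
  tensors[2].shape_n = 32;   tensors[2].shape_h = 1024;

  // Aggregate over the whole graph (here one node), then check the budget once.
  double used_memory = 0.0;
  for (const auto &t : tensors) used_memory += TensorBytes(t);

  const double device_memory = 16.0 * 1024.0 * 1024.0 * 1024.0;  // 16 GB
  std::printf("used = %.0f bytes, within budget = %d\n", used_memory,
              used_memory <= device_memory);
  return 0;
}

Each str_* factor is the fraction of a dimension kept on one device, so a strategy that cuts a dimension four ways multiplies that dimension's contribution by 0.25.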
diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.h b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.h
index 4f831f4f9a..e22b11542a 100644
--- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.h
+++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_partition.h
@@ -40,7 +40,7 @@ StrategyRec PartitionNode(const Graph::NodeType &node,
                           const std::vector<std::pair<std::string, StrategyRec>> &node_name_to_strategy,
                           std::shared_ptr<Graph> graph);
 
-Status PartitionForAllDevices(const size_t num_device, std::shared_ptr<Graph> graph);
+Status PartitionForAllDevices(const size_t num_device, const double device_memory, std::shared_ptr<Graph> graph);
 
 Graph::NodeType ApplyStrToTensor(Graph::NodeType Node);
 
@@ -50,7 +50,7 @@ void ApplyLastStrategy(const uint64_t node_index, std::shared_ptr<Graph> graph);
 
 void ApplyNextStrategy(const uint64_t node_index, std::shared_ptr<Graph> graph);
 
-Status DevicesMemoryControl(std::shared_ptr<Graph> graph);
+Status DevicesMemoryControl(const double device_memory, std::shared_ptr<Graph> graph);
 
 size_t GetDataTypeSize(const TensorType &type);
 }  // namespace parallel
diff --git a/mindspore/ccsrc/parallel/ops_info/operator_info.h b/mindspore/ccsrc/parallel/ops_info/operator_info.h
index 347da7e573..de95bd84ad 100644
--- a/mindspore/ccsrc/parallel/ops_info/operator_info.h
+++ b/mindspore/ccsrc/parallel/ops_info/operator_info.h
@@ -150,14 +150,11 @@ class OperatorInfo {
   // needed by rec_parser
   void set_type(const std::string &type) { type_ = type; }
   const std::string &type() const { return type_; }
-  void set_cnode_name(const std::string &cnode_name) { cnode_name_ = cnode_name; }
-  const std::string &cnode_name() const { return cnode_name_; }
   const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }
 
  protected:
   // needed by rec_parser
   std::string type_;
-  std::string cnode_name_;
   virtual Status CheckStrategy(const StrategyPtr &strategy) = 0;
   virtual Status InferTensorMap() = 0;
   virtual Status InferForwardCommunication() = 0;
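The header changes only re-declare `PartitionForAllDevices` and `DevicesMemoryControl` with the new `device_memory` parameter. A minimal sketch of how the budget threads through the call chain, with hypothetical stubs standing in for `Status`, `Graph`, and the elided partitioning work:

#include <cstddef>
#include <iostream>
#include <memory>

// Hypothetical stubs for illustration; the real types live in the rec_core
// graph header and the parallel Status definitions.
enum Status { SUCCESS, FAILED };
struct Graph { double used_memory = 0.0; };

// Mirrors the new declaration: the memory budget is an explicit argument
// instead of a hard-coded macro.
Status DevicesMemoryControl(const double device_memory, std::shared_ptr<Graph> graph) {
  return (graph->used_memory <= device_memory) ? SUCCESS : FAILED;
}

Status PartitionForAllDevices(const std::size_t num_device, const double device_memory,
                              std::shared_ptr<Graph> graph) {
  if (num_device < 1) return FAILED;
  // ... node sorting and per-node partitioning elided ...
  return DevicesMemoryControl(device_memory, graph);  // budget threaded through
}

int main() {
  auto graph = std::make_shared<Graph>();
  graph->used_memory = 2.0 * 1024.0 * 1024.0 * 1024.0;           // 2 GB in use
  const double device_memory = 16.0 * 1024.0 * 1024.0 * 1024.0;  // 16 GB budget
  std::cout << (PartitionForAllDevices(8, device_memory, graph) == SUCCESS) << std::endl;
  return 0;
}

Passing the budget explicitly lets callers such as step_auto_parallel.cc pick it up from the cost graph configuration rather than recompiling to change the macro.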
diff --git a/mindspore/ccsrc/parallel/step_auto_parallel.cc b/mindspore/ccsrc/parallel/step_auto_parallel.cc
index 8a95232aa4..7d37bafe98 100644
--- a/mindspore/ccsrc/parallel/step_auto_parallel.cc
+++ b/mindspore/ccsrc/parallel/step_auto_parallel.cc
@@ -935,7 +935,8 @@ Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes,
   std::shared_ptr<Graph> graph = ParseGraph(ops, input_tensor_names);
 
   size_t num_device = g_device_manager->DeviceNum();
-  if (PartitionForAllDevices(num_device, graph) == SUCCESS) {
+  double device_memory = entire_costgraph->GetDeviceMemory();
+  if (PartitionForAllDevices(num_device, device_memory, graph) == SUCCESS) {
     MS_LOG(INFO) << "Partition Success With " << num_device << " devices.";
   } else {
     MS_LOG(ERROR) << "PartitionForAllDevices failed.";
diff --git a/mindspore/ccsrc/parallel/step_parallel.cc b/mindspore/ccsrc/parallel/step_parallel.cc
index d1390db899..08f4c56d9f 100644
--- a/mindspore/ccsrc/parallel/step_parallel.cc
+++ b/mindspore/ccsrc/parallel/step_parallel.cc
@@ -2263,13 +2263,10 @@ std::vector<std::string> ExtractInputsTensorName(const CNodePtr &node) {
   std::vector<AnfNodePtr> all_inputs = node->inputs();
   std::vector<AnfNodePtr> node_inputs{all_inputs.begin() + 1, all_inputs.end()};
 
+  std::string node_id = node->UniqueId();
+  name_inputs.push_back(node_id);
   for (auto &input : node_inputs) {
-    std::string name;
-    if (IsValueNode<Tensor>(input) || input->isa<CNode>() || input->isa<Parameter>()) {
-      name = input->ToString();
-    } else {
-      continue;
-    }
+    std::string name = input->UniqueId();
     name_inputs.push_back(name);
   }
diff --git a/tests/ut/cpp/parallel/auto_parallel/rec_partition_test.cc b/tests/ut/cpp/parallel/auto_parallel/rec_partition_test.cc
index 509b00f428..1eb65b468f 100644
--- a/tests/ut/cpp/parallel/auto_parallel/rec_partition_test.cc
+++ b/tests/ut/cpp/parallel/auto_parallel/rec_partition_test.cc
@@ -227,19 +227,22 @@ TEST_F(TestPartition, test_PartitionNode) {
 
 TEST_F(TestPartition, test_PartitionForAllDevices) {
   std::shared_ptr<Graph> graph = MakeMatMulData(9);
-  ASSERT_EQ(PartitionForAllDevices(1024, graph), SUCCESS);
+  double device_memory = 1024.0 * 1024.0 * 1024.0 * 16.0;
+  ASSERT_EQ(PartitionForAllDevices(1024, device_memory, graph), SUCCESS);
 }
 
 TEST_F(TestPartition, test_PartitionForAllDevices2) {
   std::shared_ptr<Graph> graph = MakeMatMulData(9);
-  ASSERT_EQ(PartitionForAllDevices(2, graph), SUCCESS);
+  double device_memory = 1024.0 * 1024.0 * 1024.0 * 16.0;
+  ASSERT_EQ(PartitionForAllDevices(2, device_memory, graph), SUCCESS);
 }
 
 // Negative case: partition on 0 devices
 TEST_F(TestPartition, test_PartitionForAllDevices0) {
   std::shared_ptr<Graph> graph = MakeMatMulData(9);
+  double device_memory = 1024.0 * 1024.0 * 1024.0 * 16.0;
   // Throw Exception "Number of devices can't be 0"
-  EXPECT_ANY_THROW(PartitionForAllDevices(0, graph));
+  EXPECT_ANY_THROW(PartitionForAllDevices(0, device_memory, graph));
 }
 
 TEST_F(TestPartition, test_ApplyStrToTensor) {
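For the step_parallel.cc hunk, `ExtractInputsTensorName` now records the node's own unique id first and then the unique id of every input, instead of calling `ToString()` on value-node, CNode, and Parameter inputs only, so the rec parser can match operators by stable id rather than display name. A rough sketch of the new behavior, with a hypothetical `NodeSketch` standing in for the real ANF node type:

#include <iostream>
#include <string>
#include <vector>

// Hypothetical stand-in for an ANF node that can report a unique id.
struct NodeSketch {
  std::string unique_id;
  const std::string &UniqueId() const { return unique_id; }
};

// Mirrors the rewritten ExtractInputsTensorName: the node's own id comes
// first, then the unique id of every input, with no input skipped.
std::vector<std::string> ExtractInputsTensorName(const NodeSketch &node,
                                                 const std::vector<NodeSketch> &inputs) {
  std::vector<std::string> name_inputs;
  name_inputs.push_back(node.UniqueId());
  for (const auto &input : inputs) {
    name_inputs.push_back(input.UniqueId());
  }
  return name_inputs;
}

int main() {
  NodeSketch matmul{"op42"};
  std::vector<NodeSketch> inputs{{"tensor7"}, {"weight3"}};
  for (const auto &name : ExtractInputsTensorName(matmul, inputs)) {
    std::cout << name << "\n";  // op42, tensor7, weight3
  }
  return 0;
}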