|
|
@ -29,52 +29,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
namespace mindspore {
|
|
|
|
namespace parallel {
|
|
|
|
namespace parallel {
|
|
|
|
#define DEVICE_MEMORY 1024.0 * 1024.0 * 1024.0 // 1GB
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Get the target node's weight for sorting.
|
|
|
|
// Get the target node's weight for sorting.
|
|
|
|
double GetWeights(const Graph::NodeType &node) {
|
|
|
|
double GetWeights(const Graph::NodeType &node) {
|
|
|
|
const OperatorRec &op = node.apply;
|
|
|
|
const OperatorRec &op = node.apply;
|
|
|
|
|
|
|
|
|
|
|
|
if (op.op_type == 0) {
|
|
|
|
if (op.op_type == OperatorType::kRecMatMul) {
|
|
|
|
// For MatMul
|
|
|
|
// For MatMul
|
|
|
|
auto cost_ptr = std::make_shared<CostMatMul>();
|
|
|
|
auto cost_ptr = std::make_shared<CostMatMul>();
|
|
|
|
|
|
|
|
|
|
|
|
return cost_ptr->GetMinCostIn(op);
|
|
|
|
return cost_ptr->GetMinCostIn(op);
|
|
|
|
} else if (op.op_type == 1) {
|
|
|
|
} else if (op.op_type == OperatorType::kRecConvolution) {
|
|
|
|
// For Convolution
|
|
|
|
// For Convolution
|
|
|
|
auto cost_ptr = std::make_shared<CostConvolution>();
|
|
|
|
auto cost_ptr = std::make_shared<CostConvolution>();
|
|
|
|
|
|
|
|
|
|
|
|
return cost_ptr->GetMinCostIn(node);
|
|
|
|
return cost_ptr->GetMinCostIn(node);
|
|
|
|
} else if (op.op_type == 2) {
|
|
|
|
} else if (op.op_type == OperatorType::kRecPooling) {
|
|
|
|
// For Pooling
|
|
|
|
// For Pooling
|
|
|
|
auto cost_ptr = std::make_shared<CostPooling>();
|
|
|
|
auto cost_ptr = std::make_shared<CostPooling>();
|
|
|
|
|
|
|
|
|
|
|
|
return cost_ptr->GetMinCostIn();
|
|
|
|
return cost_ptr->GetMinCostIn();
|
|
|
|
} else if (op.op_type == 3) {
|
|
|
|
} else if (op.op_type == OperatorType::kRecAdd) {
|
|
|
|
// For Add
|
|
|
|
// For Add
|
|
|
|
auto cost_ptr = std::make_shared<CostAdd>();
|
|
|
|
auto cost_ptr = std::make_shared<CostAdd>();
|
|
|
|
|
|
|
|
|
|
|
|
return cost_ptr->GetMinCostIn();
|
|
|
|
return cost_ptr->GetMinCostIn();
|
|
|
|
} else if (op.op_type == 4 || op.op_type == 7 || op.op_type == 9) {
|
|
|
|
} else if (op.op_type == OperatorType::kRecSoftmax || op.op_type == OperatorType::kRecReLU ||
|
|
|
|
|
|
|
|
op.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits) {
|
|
|
|
// For Softmax & || Activation
|
|
|
|
// For Softmax & || Activation
|
|
|
|
auto cost_ptr = std::make_shared<CostCommon>();
|
|
|
|
auto cost_ptr = std::make_shared<CostCommon>();
|
|
|
|
|
|
|
|
|
|
|
|
return cost_ptr->GetMinCostIn();
|
|
|
|
return cost_ptr->GetMinCostIn();
|
|
|
|
} else if (op.op_type == 5) {
|
|
|
|
} else if (op.op_type == OperatorType::kRecReshape) {
|
|
|
|
// For Reshape
|
|
|
|
// For Reshape
|
|
|
|
auto cost_ptr = std::make_shared<CostReshape>();
|
|
|
|
auto cost_ptr = std::make_shared<CostReshape>();
|
|
|
|
|
|
|
|
|
|
|
|
return cost_ptr->GetMinCostIn();
|
|
|
|
return cost_ptr->GetMinCostIn();
|
|
|
|
} else if (op.op_type == 6) {
|
|
|
|
} else if (op.op_type == OperatorType::kRecBiasAdd) {
|
|
|
|
// For BiasAdd
|
|
|
|
// For BiasAdd
|
|
|
|
auto cost_ptr = std::make_shared<CostBiasAdd>();
|
|
|
|
auto cost_ptr = std::make_shared<CostBiasAdd>();
|
|
|
|
|
|
|
|
|
|
|
|
return cost_ptr->GetMinCostIn();
|
|
|
|
return cost_ptr->GetMinCostIn();
|
|
|
|
} else if (op.op_type == 8) {
|
|
|
|
} else if (op.op_type == OperatorType::kRecBatchNorm) {
|
|
|
|
// For BatchNorm
|
|
|
|
// For BatchNorm
|
|
|
|
auto cost_ptr = std::make_shared<CostBatchNorm>();
|
|
|
|
auto cost_ptr = std::make_shared<CostBatchNorm>();
|
|
|
|
|
|
|
|
|
|
|
|
return cost_ptr->GetMinCostIn();
|
|
|
|
return cost_ptr->GetMinCostIn();
|
|
|
|
|
|
|
|
} else if (op.op_type == OperatorType::kRecUnkownType) {
|
|
|
|
|
|
|
|
// For unknown type
|
|
|
|
|
|
|
|
return 0.0;
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
MS_LOG(EXCEPTION) << "Failure: GetOperatorWeight failed.";
|
|
|
|
MS_LOG(EXCEPTION) << "Failure: GetOperatorWeight failed.";
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -155,13 +158,17 @@ StrategyRec PartitionNode(const Graph::NodeType &node,
|
|
|
|
auto cost_ptr = std::make_shared<CostBatchNorm>();
|
|
|
|
auto cost_ptr = std::make_shared<CostBatchNorm>();
|
|
|
|
|
|
|
|
|
|
|
|
return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph);
|
|
|
|
return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph);
|
|
|
|
|
|
|
|
} else if (node.apply.op_type == 10) {
|
|
|
|
|
|
|
|
// For unknown type
|
|
|
|
|
|
|
|
StrategyRec default_strategy;
|
|
|
|
|
|
|
|
return default_strategy;
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
MS_LOG(EXCEPTION) << "Failure: Partition Operator failed.";
|
|
|
|
MS_LOG(EXCEPTION) << "Failure: Partition Operator failed.";
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Parttion graph into all devices.
|
|
|
|
// Parttion graph into all devices.
|
|
|
|
Status PartitionForAllDevices(const size_t num_device, std::shared_ptr<Graph> graph) {
|
|
|
|
Status PartitionForAllDevices(const size_t num_device, const double device_memory, std::shared_ptr<Graph> graph) {
|
|
|
|
if (num_device < 1) {
|
|
|
|
if (num_device < 1) {
|
|
|
|
MS_LOG(EXCEPTION) << "ERROR: Number of devices can't be " << num_device << ".";
|
|
|
|
MS_LOG(EXCEPTION) << "ERROR: Number of devices can't be " << num_device << ".";
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -207,7 +214,7 @@ Status PartitionForAllDevices(const size_t num_device, std::shared_ptr<Graph> gr
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
InferUndecideStrategy(graph);
|
|
|
|
InferUndecideStrategy(graph);
|
|
|
|
if (DevicesMemoryControl(graph) != SUCCESS) {
|
|
|
|
if (DevicesMemoryControl(device_memory, graph) != SUCCESS) {
|
|
|
|
return FAILED;
|
|
|
|
return FAILED;
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
return SUCCESS;
|
|
|
|
return SUCCESS;
|
|
|
@ -306,15 +313,15 @@ void ApplyNextStrategy(const uint64_t node_index, std::shared_ptr<Graph> graph)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
Status DevicesMemoryControl(std::shared_ptr<Graph> graph) {
|
|
|
|
Status DevicesMemoryControl(const double device_memory, std::shared_ptr<Graph> graph) {
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
|
|
|
|
|
|
|
|
uint64_t iter_nodes = graph->nodes.size();
|
|
|
|
uint64_t iter_nodes = graph->nodes.size();
|
|
|
|
|
|
|
|
double used_memory = 0.0;
|
|
|
|
|
|
|
|
|
|
|
|
for (uint64_t i_node = 0; i_node < iter_nodes; i_node++) {
|
|
|
|
for (uint64_t i_node = 0; i_node < iter_nodes; i_node++) {
|
|
|
|
if (graph->nodes[i_node].info == 0) {
|
|
|
|
if (graph->nodes[i_node].info == 0) {
|
|
|
|
Graph::NodeType &Node = graph->nodes[i_node];
|
|
|
|
Graph::NodeType &Node = graph->nodes[i_node];
|
|
|
|
double used_memory = 0.0;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for (int index = 0; index < 2; index++) {
|
|
|
|
for (int index = 0; index < 2; index++) {
|
|
|
|
used_memory += Node.apply.arguments[index].tensor_str.str_n * Node.apply.arguments[index].tensor_shape.shape_n *
|
|
|
|
used_memory += Node.apply.arguments[index].tensor_str.str_n * Node.apply.arguments[index].tensor_shape.shape_n *
|
|
|
@ -329,11 +336,11 @@ Status DevicesMemoryControl(std::shared_ptr<Graph> graph) {
|
|
|
|
Node.tensor_parm.tensor_str.str_h * Node.tensor_parm.tensor_shape.shape_h *
|
|
|
|
Node.tensor_parm.tensor_str.str_h * Node.tensor_parm.tensor_shape.shape_h *
|
|
|
|
Node.tensor_parm.tensor_str.str_w * Node.tensor_parm.tensor_shape.shape_w *
|
|
|
|
Node.tensor_parm.tensor_str.str_w * Node.tensor_parm.tensor_shape.shape_w *
|
|
|
|
GetDataTypeSize(Node.tensor_parm.tensor_type);
|
|
|
|
GetDataTypeSize(Node.tensor_parm.tensor_type);
|
|
|
|
if (DEVICE_MEMORY < used_memory) {
|
|
|
|
|
|
|
|
MS_LOG(EXCEPTION) << "Failure: Out of memory!";
|
|
|
|
|
|
|
|
return FAILED;
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
if (device_memory < used_memory) {
|
|
|
|
|
|
|
|
MS_LOG(EXCEPTION) << "Failure: Out of memory!";
|
|
|
|
|
|
|
|
return FAILED;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
return SUCCESS;
|
|
|
|
return SUCCESS;
|
|
|
|