use DeviceMemory for memory control

Branch: pull/556/head
Author: ch-l (5 years ago), committed by klchai
Parent: 3c307cf486
Commit: f806b72447

@@ -29,52 +29,55 @@
 namespace mindspore {
 namespace parallel {
-#define DEVICE_MEMORY 1024.0 * 1024.0 * 1024.0  // 1GB
 // Get the target node's weight for sorting.
 double GetWeights(const Graph::NodeType &node) {
   const OperatorRec &op = node.apply;
-  if (op.op_type == 0) {
+  if (op.op_type == OperatorType::kRecMatMul) {
     // For MatMul
     auto cost_ptr = std::make_shared<CostMatMul>();
     return cost_ptr->GetMinCostIn(op);
-  } else if (op.op_type == 1) {
+  } else if (op.op_type == OperatorType::kRecConvolution) {
     // For Convolution
     auto cost_ptr = std::make_shared<CostConvolution>();
     return cost_ptr->GetMinCostIn(node);
-  } else if (op.op_type == 2) {
+  } else if (op.op_type == OperatorType::kRecPooling) {
     // For Pooling
     auto cost_ptr = std::make_shared<CostPooling>();
     return cost_ptr->GetMinCostIn();
-  } else if (op.op_type == 3) {
+  } else if (op.op_type == OperatorType::kRecAdd) {
     // For Add
     auto cost_ptr = std::make_shared<CostAdd>();
     return cost_ptr->GetMinCostIn();
-  } else if (op.op_type == 4 || op.op_type == 7 || op.op_type == 9) {
+  } else if (op.op_type == OperatorType::kRecSoftmax || op.op_type == OperatorType::kRecReLU ||
+             op.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits) {
     // For Softmax & || Activation
     auto cost_ptr = std::make_shared<CostCommon>();
     return cost_ptr->GetMinCostIn();
-  } else if (op.op_type == 5) {
+  } else if (op.op_type == OperatorType::kRecReshape) {
     // For Reshape
     auto cost_ptr = std::make_shared<CostReshape>();
     return cost_ptr->GetMinCostIn();
-  } else if (op.op_type == 6) {
+  } else if (op.op_type == OperatorType::kRecBiasAdd) {
     // For BiasAdd
     auto cost_ptr = std::make_shared<CostBiasAdd>();
     return cost_ptr->GetMinCostIn();
-  } else if (op.op_type == 8) {
+  } else if (op.op_type == OperatorType::kRecBatchNorm) {
     // For BatchNorm
     auto cost_ptr = std::make_shared<CostBatchNorm>();
     return cost_ptr->GetMinCostIn();
+  } else if (op.op_type == OperatorType::kRecUnkownType) {
+    // For unknown type
+    return 0.0;
   } else {
     MS_LOG(EXCEPTION) << "Failure: GetOperatorWeight failed.";
   }
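
One detail worth spelling out for the hunk above: each numeric literal on the removed side is replaced by exactly one OperatorType enumerator on the added side, which implies the value mapping sketched below. This enum is reconstructed only from those substitutions (plus the literal 10 used for the unknown case in the next hunk); it is an assumption for illustration, not a copy of the actual MindSpore header.

// Hypothetical reconstruction of OperatorType, inferred from the
// literal -> enumerator substitutions in this diff; values are assumed.
enum class OperatorType : int {
  kRecMatMul = 0,                               // was op_type == 0
  kRecConvolution = 1,                          // was op_type == 1
  kRecPooling = 2,                              // was op_type == 2
  kRecAdd = 3,                                  // was op_type == 3
  kRecSoftmax = 4,                              // was op_type == 4
  kRecReshape = 5,                              // was op_type == 5
  kRecBiasAdd = 6,                              // was op_type == 6
  kRecReLU = 7,                                 // was op_type == 7
  kRecBatchNorm = 8,                            // was op_type == 8
  kRecSparseSoftmaxCrossEntropyWithLogits = 9,  // was op_type == 9
  kRecUnkownType = 10                           // "unknown type" branch below still compares against 10
};
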
@@ -155,13 +158,17 @@ StrategyRec PartitionNode(const Graph::NodeType &node,
     auto cost_ptr = std::make_shared<CostBatchNorm>();
     return cost_ptr->GetOptimalStr(node, node_name_to_strategy, *graph);
+  } else if (node.apply.op_type == 10) {
+    // For unknown type
+    StrategyRec default_strategy;
+    return default_strategy;
   } else {
     MS_LOG(EXCEPTION) << "Failure: Partition Operator failed.";
   }
 }
 // Parttion graph into all devices.
-Status PartitionForAllDevices(const size_t num_device, std::shared_ptr<Graph> graph) {
+Status PartitionForAllDevices(const size_t num_device, const double device_memory, std::shared_ptr<Graph> graph) {
   if (num_device < 1) {
     MS_LOG(EXCEPTION) << "ERROR: Number of devices can't be " << num_device << ".";
   }
@@ -207,7 +214,7 @@ Status PartitionForAllDevices(const size_t num_device, std::shared_ptr<Graph> gr
   }
   InferUndecideStrategy(graph);
-  if (DevicesMemoryControl(graph) != SUCCESS) {
+  if (DevicesMemoryControl(device_memory, graph) != SUCCESS) {
     return FAILED;
   } else {
     return SUCCESS;
@@ -306,15 +313,15 @@ void ApplyNextStrategy(const uint64_t node_index, std::shared_ptr<Graph> graph)
   }
 }
-Status DevicesMemoryControl(std::shared_ptr<Graph> graph) {
+Status DevicesMemoryControl(const double device_memory, std::shared_ptr<Graph> graph) {
   MS_EXCEPTION_IF_NULL(graph);
   uint64_t iter_nodes = graph->nodes.size();
+  double used_memory = 0.0;
   for (uint64_t i_node = 0; i_node < iter_nodes; i_node++) {
     if (graph->nodes[i_node].info == 0) {
       Graph::NodeType &Node = graph->nodes[i_node];
-      double used_memory = 0.0;
       for (int index = 0; index < 2; index++) {
         used_memory += Node.apply.arguments[index].tensor_str.str_n * Node.apply.arguments[index].tensor_shape.shape_n *
@@ -329,12 +336,12 @@ Status DevicesMemoryControl(std::shared_ptr<Graph> graph) {
                      Node.tensor_parm.tensor_str.str_h * Node.tensor_parm.tensor_shape.shape_h *
                      Node.tensor_parm.tensor_str.str_w * Node.tensor_parm.tensor_shape.shape_w *
                      GetDataTypeSize(Node.tensor_parm.tensor_type);
-      if (DEVICE_MEMORY < used_memory) {
-        MS_LOG(EXCEPTION) << "Failure: Out of memory!";
-        return FAILED;
-      }
     }
   }
+
+  if (device_memory < used_memory) {
+    MS_LOG(EXCEPTION) << "Failure: Out of memory!";
+    return FAILED;
+  }
 
   return SUCCESS;
 }
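
The behavioral change in DevicesMemoryControl is easy to miss in the hunk: used_memory used to be reset for every node and checked inside the loop against the fixed 1 GB DEVICE_MEMORY macro, whereas it is now accumulated over the whole graph and compared once against the device_memory value supplied by the caller. A minimal standalone sketch of the new accumulate-then-check pattern, using simplified hypothetical types rather than the real Graph structures:

#include <iostream>
#include <vector>

// Simplified stand-in for a graph node (assumption for illustration only).
struct FakeNode {
  double sharded_bytes;  // bytes this node's sharded tensors occupy on one device
};

enum Status { SUCCESS, FAILED };

// Mirrors the new control flow: sum over all nodes, then a single budget check.
Status DevicesMemoryControlSketch(double device_memory, const std::vector<FakeNode> &nodes) {
  double used_memory = 0.0;            // accumulated across the whole graph
  for (const auto &node : nodes) {
    used_memory += node.sharded_bytes;
  }
  if (device_memory < used_memory) {   // one check at the end, not per node
    std::cerr << "Failure: Out of memory!\n";
    return FAILED;
  }
  return SUCCESS;
}

int main() {
  std::vector<FakeNode> nodes{{0.5e9}, {0.7e9}};
  double device_memory = 1.0e9;        // 1 GB budget, as the old macro assumed
  std::cout << (DevicesMemoryControlSketch(device_memory, nodes) == SUCCESS ? "fits" : "too big")
            << std::endl;              // prints "too big": 1.2 GB total exceeds 1 GB
  return 0;
}

The per-node check in the old code would pass any graph whose nodes each fit in 1 GB individually, even if their total did not; the accumulated check catches that case, and the budget now comes from the caller instead of a compile-time constant.
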

@@ -40,7 +40,7 @@ StrategyRec PartitionNode(const Graph::NodeType &node,
                           const std::vector<std::pair<std::string, StrategyRec>> &node_name_to_strategy,
                           std::shared_ptr<Graph> graph);
-Status PartitionForAllDevices(const size_t num_device, std::shared_ptr<Graph> graph);
+Status PartitionForAllDevices(const size_t num_device, const double device_memory, std::shared_ptr<Graph> graph);
 Graph::NodeType ApplyStrToTensor(Graph::NodeType Node);
@@ -50,7 +50,7 @@ void ApplyLastStrategy(const uint64_t node_index, std::shared_ptr<Graph> graph);
 void ApplyNextStrategy(const uint64_t node_index, std::shared_ptr<Graph> graph);
-Status DevicesMemoryControl(std::shared_ptr<Graph> graph);
+Status DevicesMemoryControl(const double device_memory, std::shared_ptr<Graph> graph);
 size_t GetDataTypeSize(const TensorType &type);
 }  // namespace parallel

@@ -150,14 +150,11 @@ class OperatorInfo {
   // needed by rec_parser
   void set_type(const std::string &type) { type_ = type; }
   const std::string &type() const { return type_; }
-  void set_cnode_name(const std::string &cnode_name) { cnode_name_ = cnode_name; }
-  const std::string &cnode_name() const { return cnode_name_; }
   const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }
  protected:
   // needed by rec_parser
   std::string type_;
-  std::string cnode_name_;
   virtual Status CheckStrategy(const StrategyPtr &strategy) = 0;
   virtual Status InferTensorMap() = 0;
   virtual Status InferForwardCommunication() = 0;

@@ -935,7 +935,8 @@ Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const
   std::shared_ptr<Graph> graph = ParseGraph(ops, input_tensor_names);
   size_t num_device = g_device_manager->DeviceNum();
-  if (PartitionForAllDevices(num_device, graph) == SUCCESS) {
+  double device_memory = entire_costgraph->GetDeviceMemory();
+  if (PartitionForAllDevices(num_device, device_memory, graph) == SUCCESS) {
     MS_LOG(INFO) << "Partition Success With " << num_device << " devices.";
   } else {
     MS_LOG(ERROR) << "PartitionForAllDevices failed.";

@@ -2263,13 +2263,10 @@ std::vector<std::string> ExtractInputsTensorName(const CNodePtr &node) {
   std::vector<AnfNodePtr> all_inputs = node->inputs();
   std::vector<AnfNodePtr> node_inputs{all_inputs.begin() + 1, all_inputs.end()};
+  std::string node_id = node->UniqueId();
+  name_inputs.push_back(node_id);
   for (auto &input : node_inputs) {
-    std::string name;
-    if (IsValueNode<Tensor>(input) || input->isa<CNode>() || input->isa<Parameter>()) {
-      name = input->ToString();
-    } else {
-      continue;
-    }
+    std::string name = input->UniqueId();
     name_inputs.push_back(name);
   }
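
The ExtractInputsTensorName change also alters the shape of the returned vector: it now begins with the node's own UniqueId and then appends every input's UniqueId unconditionally, instead of returning ToString() names for a filtered subset of inputs. A hedged sketch of the new naming scheme over simplified hypothetical types:

#include <string>
#include <vector>

// Simplified stand-in for an AnfNode (assumption for illustration only).
struct FakeAnfNode {
  std::string unique_id;
  const std::string &UniqueId() const { return unique_id; }
};

// Mirrors the new behavior: the node's own id first, then every input's id,
// with no filtering by input kind.
std::vector<std::string> ExtractInputsTensorNameSketch(const FakeAnfNode &node,
                                                       const std::vector<FakeAnfNode> &inputs) {
  std::vector<std::string> name_inputs;
  name_inputs.push_back(node.UniqueId());
  for (const auto &input : inputs) {
    name_inputs.push_back(input.UniqueId());
  }
  return name_inputs;
}
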

@@ -227,19 +227,22 @@ TEST_F(TestPartition, test_PartitionNode) {
 TEST_F(TestPartition, test_PartitionForAllDevices) {
   std::shared_ptr<Graph> graph = MakeMatMulData(9);
-  ASSERT_EQ(PartitionForAllDevices(1024, graph), SUCCESS);
+  double device_memory = 1024.0 * 1024.0 * 1024.0 * 16.0;
+  ASSERT_EQ(PartitionForAllDevices(1024, device_memory, graph), SUCCESS);
 }
 
 TEST_F(TestPartition, test_PartitionForAllDevices2) {
   std::shared_ptr<Graph> graph = MakeMatMulData(9);
-  ASSERT_EQ(PartitionForAllDevices(2, graph), SUCCESS);
+  double device_memory = 1024.0 * 1024.0 * 1024.0 * 16.0;
+  ASSERT_EQ(PartitionForAllDevices(2, device_memory, graph), SUCCESS);
 }
 
 // Negative case: parition on 0 device
 TEST_F(TestPartition, test_PartitionForAllDevices0) {
   std::shared_ptr<Graph> graph = MakeMatMulData(9);
+  double device_memory = 1024.0 * 1024.0 * 1024.0 * 16.0;
   // Throw Exception "Number of devices can't be 0"
-  EXPECT_ANY_THROW(PartitionForAllDevices(0, graph));
+  EXPECT_ANY_THROW(PartitionForAllDevices(0, device_memory, graph));
 }
 
 TEST_F(TestPartition, test_ApplyStrToTensor) {
