|
|
|
@ -232,7 +232,7 @@ Status PartitionForAllDevices(const size_t num_device, const double device_memor
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (DevicesMemoryControl(device_memory, graph) != SUCCESS) {
|
|
|
|
|
if (DevicesMemoryControl(num_device, device_memory, graph) != SUCCESS) {
|
|
|
|
|
return FAILED;
|
|
|
|
|
} else {
|
|
|
|
|
return SUCCESS;
|
|
|
|
@ -257,16 +257,15 @@ Graph::NodeType ApplyStrToTensor(Graph::NodeType Node) {
|
|
|
|
|
return Node;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Status DevicesMemoryControl(const double device_memory, std::shared_ptr<Graph> graph) {
|
|
|
|
|
Status DevicesMemoryControl(const size_t num_device, const double device_memory, std::shared_ptr<Graph> graph) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
|
|
|
|
|
|
uint64_t iter_nodes = graph->nodes.size();
|
|
|
|
|
double used_memory = 0.0;
|
|
|
|
|
|
|
|
|
|
for (uint64_t i_node = 0; i_node < iter_nodes; i_node++) {
|
|
|
|
|
if (graph->nodes[i_node].info == 0) {
|
|
|
|
|
Graph::NodeType &Node = graph->nodes[i_node];
|
|
|
|
|
double used_memory = 0.0;
|
|
|
|
|
|
|
|
|
|
for (int index = 0; index < 2; index++) {
|
|
|
|
|
used_memory += Node.apply.arguments[index].tensor_str.str_n * Node.apply.arguments[index].tensor_shape.shape_n *
|
|
|
|
|
Node.apply.arguments[index].tensor_str.str_c * Node.apply.arguments[index].tensor_shape.shape_c *
|
|
|
|
@ -274,21 +273,15 @@ Status DevicesMemoryControl(const double device_memory, std::shared_ptr<Graph> g
|
|
|
|
|
Node.apply.arguments[index].tensor_str.str_w * Node.apply.arguments[index].tensor_shape.shape_w *
|
|
|
|
|
GetDataTypeSize(Node.apply.arguments[index].tensor_type);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
used_memory += Node.tensor_parm.tensor_str.str_n * Node.tensor_parm.tensor_shape.shape_n *
|
|
|
|
|
Node.tensor_parm.tensor_str.str_c * Node.tensor_parm.tensor_shape.shape_c *
|
|
|
|
|
Node.tensor_parm.tensor_str.str_h * Node.tensor_parm.tensor_shape.shape_h *
|
|
|
|
|
Node.tensor_parm.tensor_str.str_w * Node.tensor_parm.tensor_shape.shape_w *
|
|
|
|
|
GetDataTypeSize(Node.tensor_parm.tensor_type);
|
|
|
|
|
|
|
|
|
|
if (device_memory < used_memory) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Failure: Out of memory!";
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
if (device_memory < (used_memory / num_device)) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Failure: Out of memory!";
|
|
|
|
|
return FAILED;
|
|
|
|
|
} else {
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
size_t GetDataTypeSize(const TensorType &type) {
|
|
|
|
|