|
|
@ -73,7 +73,7 @@ double GetWeights(const Graph::NodeType &node) {
|
|
|
|
// For BatchNorm
|
|
|
|
// For BatchNorm
|
|
|
|
auto cost_ptr = std::make_shared<CostBatchNorm>();
|
|
|
|
auto cost_ptr = std::make_shared<CostBatchNorm>();
|
|
|
|
|
|
|
|
|
|
|
|
return cost_ptr->GetMinCostIn();
|
|
|
|
return cost_ptr->GetMinCostIn(op);
|
|
|
|
} else if (op.op_type == OperatorType::kRecOneHot || op.op_type == OperatorType::kRecLog ||
|
|
|
|
} else if (op.op_type == OperatorType::kRecOneHot || op.op_type == OperatorType::kRecLog ||
|
|
|
|
op.op_type == OperatorType::kRecExp || op.op_type == OperatorType::kRecAdd ||
|
|
|
|
op.op_type == OperatorType::kRecExp || op.op_type == OperatorType::kRecAdd ||
|
|
|
|
op.op_type == OperatorType::kRecSub || op.op_type == OperatorType::kRecMul ||
|
|
|
|
op.op_type == OperatorType::kRecSub || op.op_type == OperatorType::kRecMul ||
|
|
|
@ -108,8 +108,8 @@ std::vector<size_t> SortByWeight(const std::shared_ptr<Graph> graph) {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Do sorting.
|
|
|
|
// Ordering ops aka nodes of the graph
|
|
|
|
sort(weight_to_node_index.begin(), weight_to_node_index.end());
|
|
|
|
std::sort(weight_to_node_index.begin(), weight_to_node_index.end());
|
|
|
|
|
|
|
|
|
|
|
|
// Store the result in node_index_by_weights.
|
|
|
|
// Store the result in node_index_by_weights.
|
|
|
|
uint64_t size = weight_to_node_index.size();
|
|
|
|
uint64_t size = weight_to_node_index.size();
|
|
|
@ -231,7 +231,6 @@ Status PartitionForAllDevices(const size_t num_device, const double device_memor
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
InferUndecideStrategy(graph);
|
|
|
|
|
|
|
|
if (DevicesMemoryControl(device_memory, graph) != SUCCESS) {
|
|
|
|
if (DevicesMemoryControl(device_memory, graph) != SUCCESS) {
|
|
|
|
return FAILED;
|
|
|
|
return FAILED;
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
@ -257,80 +256,6 @@ Graph::NodeType ApplyStrToTensor(Graph::NodeType Node) {
|
|
|
|
return Node;
|
|
|
|
return Node;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Check Strategy for the same tensor between op.
|
|
|
|
|
|
|
|
void InferUndecideStrategy(std::shared_ptr<Graph> graph) {
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
uint64_t iter_nodes = graph->nodes.size();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// For all the nodes in the graph
|
|
|
|
|
|
|
|
for (uint64_t i_node = 0; i_node < iter_nodes; i_node++) {
|
|
|
|
|
|
|
|
// If this target node is an operator, find it's adjecent op's strategy;
|
|
|
|
|
|
|
|
if (graph->nodes[i_node].info == 0) {
|
|
|
|
|
|
|
|
// Try to apply last op's strategy.
|
|
|
|
|
|
|
|
ApplyLastStrategy(i_node, graph);
|
|
|
|
|
|
|
|
// Try to apply next op's strategy.
|
|
|
|
|
|
|
|
ApplyNextStrategy(i_node, graph);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void ApplyLastStrategy(const uint64_t node_index, std::shared_ptr<Graph> graph) {
|
|
|
|
|
|
|
|
Graph::NodeType &target_node = graph->nodes[node_index];
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Number of node-in
|
|
|
|
|
|
|
|
size_t num_node_in = target_node.node_in.size();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Find forward op and copy strategy if meets the limits.
|
|
|
|
|
|
|
|
for (size_t index = 0; index < num_node_in; index++) {
|
|
|
|
|
|
|
|
if (graph->nodes[target_node.node_in[index]].tensor_parm.tensor_str.str_n <=
|
|
|
|
|
|
|
|
target_node.apply.arguments[0].tensor_str.str_n &&
|
|
|
|
|
|
|
|
graph->nodes[target_node.node_in[index]].tensor_parm.tensor_str.str_c <=
|
|
|
|
|
|
|
|
target_node.apply.arguments[0].tensor_str.str_c &&
|
|
|
|
|
|
|
|
graph->nodes[target_node.node_in[index]].tensor_parm.tensor_str.str_h <=
|
|
|
|
|
|
|
|
target_node.apply.arguments[0].tensor_str.str_h &&
|
|
|
|
|
|
|
|
graph->nodes[target_node.node_in[index]].tensor_parm.tensor_str.str_w <=
|
|
|
|
|
|
|
|
target_node.apply.arguments[0].tensor_str.str_w) {
|
|
|
|
|
|
|
|
target_node.apply.arguments[0].tensor_str.str_n =
|
|
|
|
|
|
|
|
graph->nodes[target_node.node_in[index]].tensor_parm.tensor_str.str_n;
|
|
|
|
|
|
|
|
target_node.apply.arguments[0].tensor_str.str_c =
|
|
|
|
|
|
|
|
graph->nodes[target_node.node_in[index]].tensor_parm.tensor_str.str_c;
|
|
|
|
|
|
|
|
target_node.apply.arguments[0].tensor_str.str_h =
|
|
|
|
|
|
|
|
graph->nodes[target_node.node_in[index]].tensor_parm.tensor_str.str_h;
|
|
|
|
|
|
|
|
target_node.apply.arguments[0].tensor_str.str_w =
|
|
|
|
|
|
|
|
graph->nodes[target_node.node_in[index]].tensor_parm.tensor_str.str_w;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void ApplyNextStrategy(const uint64_t node_index, std::shared_ptr<Graph> graph) {
|
|
|
|
|
|
|
|
Graph::NodeType &target_node = graph->nodes[node_index];
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Number of node-out
|
|
|
|
|
|
|
|
size_t num_node_out = target_node.node_out.size();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Find backward op and copy strategy if meets the limits.
|
|
|
|
|
|
|
|
for (size_t index = 0; index < num_node_out; index++) {
|
|
|
|
|
|
|
|
if (graph->nodes[target_node.node_out[index]].apply.arguments[0].tensor_str.str_n <=
|
|
|
|
|
|
|
|
target_node.tensor_parm.tensor_str.str_n &&
|
|
|
|
|
|
|
|
graph->nodes[target_node.node_out[index]].apply.arguments[0].tensor_str.str_c <=
|
|
|
|
|
|
|
|
target_node.tensor_parm.tensor_str.str_c &&
|
|
|
|
|
|
|
|
graph->nodes[target_node.node_out[index]].apply.arguments[0].tensor_str.str_h <=
|
|
|
|
|
|
|
|
target_node.tensor_parm.tensor_str.str_h &&
|
|
|
|
|
|
|
|
graph->nodes[target_node.node_out[index]].apply.arguments[0].tensor_str.str_w <=
|
|
|
|
|
|
|
|
target_node.tensor_parm.tensor_str.str_w) {
|
|
|
|
|
|
|
|
target_node.tensor_parm.tensor_str.str_n =
|
|
|
|
|
|
|
|
graph->nodes[target_node.node_out[index]].apply.arguments[0].tensor_str.str_n;
|
|
|
|
|
|
|
|
target_node.tensor_parm.tensor_str.str_c =
|
|
|
|
|
|
|
|
graph->nodes[target_node.node_out[index]].apply.arguments[0].tensor_str.str_c;
|
|
|
|
|
|
|
|
target_node.tensor_parm.tensor_str.str_h =
|
|
|
|
|
|
|
|
graph->nodes[target_node.node_out[index]].apply.arguments[0].tensor_str.str_h;
|
|
|
|
|
|
|
|
target_node.tensor_parm.tensor_str.str_w =
|
|
|
|
|
|
|
|
graph->nodes[target_node.node_out[index]].apply.arguments[0].tensor_str.str_w;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Status DevicesMemoryControl(const double device_memory, std::shared_ptr<Graph> graph) {
|
|
|
|
Status DevicesMemoryControl(const double device_memory, std::shared_ptr<Graph> graph) {
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
|
|
|
|
|
|
|
|