|
|
|
@ -527,6 +527,10 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
|
|
|
|
|
MS_LOG(EXCEPTION) << "The OperatorInfo: " << current_op_ptr->name()
|
|
|
|
|
<< " does not match the Prim: " << prim->name();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Needed by rec_parser
|
|
|
|
|
ModifyInputsTensorNameListIfOperatorInfoCreated(current_op_ptr->name(), cnode->UniqueId());
|
|
|
|
|
|
|
|
|
|
cnode->set_user_data<OperatorInfo>(current_op_ptr);
|
|
|
|
|
MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
|
|
|
|
|
<< " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
|
|
|
|
@ -1124,6 +1128,27 @@ CNodePtr GetInternalOperatorInfo(const CNodePtr &cnode, const ValueNodePtr &prim
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ModifyInputsTensorNameListIfOperatorInfoCreated(const std::string &name, const std::string &uniqueid) {
|
|
|
|
|
size_t iter_ops = 0;
|
|
|
|
|
for (auto op : entire_costgraph->GetOperators()) {
|
|
|
|
|
if (op->name() == name) {
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
iter_ops = iter_ops + 1;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<std::vector<std::string>> input_tensor_names = entire_costgraph->get_inputs_tensor_name_list();
|
|
|
|
|
for (size_t i = 0; i < input_tensor_names.size(); i++) {
|
|
|
|
|
for (size_t j = 0; j < input_tensor_names[i].size(); j++) {
|
|
|
|
|
if (input_tensor_names[i][j] == uniqueid) {
|
|
|
|
|
input_tensor_names[i][j] = input_tensor_names[iter_ops][0];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
entire_costgraph->set_inputs_tensor_name_list(input_tensor_names);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
|
|
|
|
|
if (CostModelContext::GetInstance()->is_multi_subgraphs()) {
|
|
|
|
|
if (ConstructCostGraphNodesByUniqueIdTC(all_nodes, root) == SUCCESS) {
|
|
|
|
|