!4054 [AutoParallel] Use uniqueid to manage input tensors

Merge pull request !4054 from Chong/wd
pull/4054/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 6657adfaef

@ -197,6 +197,9 @@ class CostGraph {
inputs_tensor_name_list_.push_back(inputs_tensor_name);
}
const std::vector<std::vector<std::string>> get_inputs_tensor_name_list() const { return inputs_tensor_name_list_; }
void set_inputs_tensor_name_list(const std::vector<std::vector<std::string>> &inputs_tensor_name_list) {
inputs_tensor_name_list_ = inputs_tensor_name_list;
}
void add_tuple_getitem(const std::pair<std::string, std::string> &tuple_getitem) {
auto ret = tuple_getitem_list_.insert(tuple_getitem);
if (ret.second == false) {

@ -527,6 +527,10 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
MS_LOG(EXCEPTION) << "The OperatorInfo: " << current_op_ptr->name()
<< " does not match the Prim: " << prim->name();
}
// Needed by rec_parser
ModifyInputsTensorNameListIfOperatorInfoCreated(current_op_ptr->name(), cnode->UniqueId());
cnode->set_user_data<OperatorInfo>(current_op_ptr);
MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
<< " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
@ -1124,6 +1128,27 @@ CNodePtr GetInternalOperatorInfo(const CNodePtr &cnode, const ValueNodePtr &prim
return nullptr;
}
void ModifyInputsTensorNameListIfOperatorInfoCreated(const std::string &name, const std::string &uniqueid) {
size_t iter_ops = 0;
for (auto op : entire_costgraph->GetOperators()) {
if (op->name() == name) {
break;
}
iter_ops = iter_ops + 1;
}
std::vector<std::vector<std::string>> input_tensor_names = entire_costgraph->get_inputs_tensor_name_list();
for (size_t i = 0; i < input_tensor_names.size(); i++) {
for (size_t j = 0; j < input_tensor_names[i].size(); j++) {
if (input_tensor_names[i][j] == uniqueid) {
input_tensor_names[i][j] = input_tensor_names[iter_ops][0];
}
}
}
entire_costgraph->set_inputs_tensor_name_list(input_tensor_names);
}
Status ParallelStrategyRecSearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
if (CostModelContext::GetInstance()->is_multi_subgraphs()) {
if (ConstructCostGraphNodesByUniqueIdTC(all_nodes, root) == SUCCESS) {

@ -59,6 +59,8 @@ std::vector<std::vector<std::string>> RecInputTensorNames(const std::map<std::st
std::vector<std::vector<std::string>> input_tensor_names);
CNodePtr GetInternalOperatorInfo(const CNodePtr &cnode, const ValueNodePtr &prim_anf_node);
void ModifyInputsTensorNameListIfOperatorInfoCreated(const std::string &name, const std::string &uniqueid);
} // namespace parallel
} // namespace mindspore
#endif // PARALLEL_STEP_AUTO_PARALLEL_H_

Loading…
Cancel
Save