|
|
@ -677,13 +677,13 @@ void KernelGraph::SetRealInput(const AnfNodePtr ¶meter, const AnfNodePtr &ar
|
|
|
|
MS_EXCEPTION_IF_NULL(parameter);
|
|
|
|
MS_EXCEPTION_IF_NULL(parameter);
|
|
|
|
MS_EXCEPTION_IF_NULL(arg);
|
|
|
|
MS_EXCEPTION_IF_NULL(arg);
|
|
|
|
if (real_inputs_.find(parameter) == real_inputs_.end()) {
|
|
|
|
if (real_inputs_.find(parameter) == real_inputs_.end()) {
|
|
|
|
real_inputs_[parameter] = std::set<AnfNodePtr>();
|
|
|
|
real_inputs_[parameter] = std::vector<AnfNodePtr>();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
auto &args = real_inputs_[parameter];
|
|
|
|
auto &args = real_inputs_[parameter];
|
|
|
|
(void)args.insert(arg);
|
|
|
|
(void)args.push_back(arg);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
std::set<AnfNodePtr> KernelGraph::GetRealInput(const AnfNodePtr ¶meter) {
|
|
|
|
std::vector<AnfNodePtr> KernelGraph::GetRealInput(const AnfNodePtr ¶meter) {
|
|
|
|
MS_EXCEPTION_IF_NULL(parameter);
|
|
|
|
MS_EXCEPTION_IF_NULL(parameter);
|
|
|
|
auto iter = real_inputs_.find(parameter);
|
|
|
|
auto iter = real_inputs_.find(parameter);
|
|
|
|
if (iter != real_inputs_.end()) {
|
|
|
|
if (iter != real_inputs_.end()) {
|
|
|
@ -694,7 +694,7 @@ std::set<AnfNodePtr> KernelGraph::GetRealInput(const AnfNodePtr ¶meter) {
|
|
|
|
|
|
|
|
|
|
|
|
void KernelGraph::UpdateCallRealInput() {
|
|
|
|
void KernelGraph::UpdateCallRealInput() {
|
|
|
|
MS_LOG(INFO) << "Update graph id: " << graph_id_;
|
|
|
|
MS_LOG(INFO) << "Update graph id: " << graph_id_;
|
|
|
|
std::map<AnfNodePtr, std::set<AnfNodePtr>> real_inputs_map;
|
|
|
|
std::map<AnfNodePtr, std::vector<AnfNodePtr>> real_inputs_map;
|
|
|
|
for (auto &it : real_inputs_) {
|
|
|
|
for (auto &it : real_inputs_) {
|
|
|
|
auto parameter = it.first;
|
|
|
|
auto parameter = it.first;
|
|
|
|
MS_EXCEPTION_IF_NULL(parameter);
|
|
|
|
MS_EXCEPTION_IF_NULL(parameter);
|
|
|
@ -713,12 +713,18 @@ void KernelGraph::UpdateCallRealInput() {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
for (auto &erase_node : erase_real_inputs) {
|
|
|
|
for (auto &erase_node : erase_real_inputs) {
|
|
|
|
MS_LOG(INFO) << "paramter: " << parameter->DebugString() << " erase real input:" << erase_node->DebugString();
|
|
|
|
MS_LOG(INFO) << "paramter: " << parameter->DebugString() << " erase real input:" << erase_node->DebugString();
|
|
|
|
(void)real_inputs.erase(erase_node);
|
|
|
|
for (auto iter = real_inputs.begin(); iter != real_inputs.end();) {
|
|
|
|
|
|
|
|
if (*iter == erase_node) {
|
|
|
|
|
|
|
|
iter = real_inputs.erase(iter);
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
++iter;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
for (auto &new_real_input : new_real_inputs) {
|
|
|
|
for (auto &new_real_input : new_real_inputs) {
|
|
|
|
MS_LOG(INFO) << "paramter: " << parameter->DebugString()
|
|
|
|
MS_LOG(INFO) << "paramter: " << parameter->DebugString()
|
|
|
|
<< " insert real input:" << new_real_input->DebugString();
|
|
|
|
<< " insert real input:" << new_real_input->DebugString();
|
|
|
|
(void)real_inputs.insert(new_real_input);
|
|
|
|
(void)real_inputs.push_back(new_real_input);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
real_inputs_map[parameter] = real_inputs;
|
|
|
|
real_inputs_map[parameter] = real_inputs;
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -730,18 +736,28 @@ void KernelGraph::PrintGraphExecuteOrder() const {
|
|
|
|
for (size_t i = 0; i < execution_order_.size(); i++) {
|
|
|
|
for (size_t i = 0; i < execution_order_.size(); i++) {
|
|
|
|
CNodePtr cur_cnode_ptr = execution_order_[i];
|
|
|
|
CNodePtr cur_cnode_ptr = execution_order_[i];
|
|
|
|
MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
|
|
|
|
MS_EXCEPTION_IF_NULL(cur_cnode_ptr);
|
|
|
|
if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kSendOpName || AnfAlgo::GetCNodeName(cur_cnode_ptr) == kRecvOpName) {
|
|
|
|
std::string event_str;
|
|
|
|
auto primitive = AnfAlgo::GetCNodePrimitive(cur_cnode_ptr);
|
|
|
|
std::string label_str;
|
|
|
|
MS_LOG(INFO) << "index[" << i << "], node name[" << AnfAlgo::GetCNodeName(cur_cnode_ptr) << "], logic id["
|
|
|
|
if (AnfAlgo::HasNodeAttr(kAttrEventId, cur_cnode_ptr)) {
|
|
|
|
<< AnfAlgo::GetStreamDistinctionLabel(cur_cnode_ptr.get()) << "], stream id["
|
|
|
|
event_str = ", event_id[" + std::to_string(AnfAlgo::GetNodeAttr<uint32_t>(cur_cnode_ptr, kAttrEventId)) + "]";
|
|
|
|
<< AnfAlgo::GetStreamId(cur_cnode_ptr) << "], event_id["
|
|
|
|
}
|
|
|
|
<< GetValue<uint32_t>(primitive->GetAttr(kAttrEventId)) << "], node info["
|
|
|
|
|
|
|
|
<< cur_cnode_ptr->DebugString() << "]";
|
|
|
|
if (AnfAlgo::HasNodeAttr(kAttrLabelIndex, cur_cnode_ptr)) {
|
|
|
|
} else {
|
|
|
|
label_str = ", label_id[" + std::to_string(AnfAlgo::GetNodeAttr<uint32_t>(cur_cnode_ptr, kAttrLabelIndex)) + "]";
|
|
|
|
MS_LOG(INFO) << "index[" << i << "], node name[" << cur_cnode_ptr->fullname_with_scope() << "], logic id["
|
|
|
|
|
|
|
|
<< AnfAlgo::GetStreamDistinctionLabel(cur_cnode_ptr.get()) << "], stream id["
|
|
|
|
|
|
|
|
<< AnfAlgo::GetStreamId(cur_cnode_ptr) << "], node info[" << cur_cnode_ptr->DebugString() << "]";
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (AnfAlgo::HasNodeAttr(kAttrLabelSwitchList, cur_cnode_ptr)) {
|
|
|
|
|
|
|
|
auto label_list = AnfAlgo::GetNodeAttr<std::vector<uint32_t>>(cur_cnode_ptr, kAttrLabelSwitchList);
|
|
|
|
|
|
|
|
label_str = ", label_id[";
|
|
|
|
|
|
|
|
for (size_t j = 0; j < label_list.size(); ++j) {
|
|
|
|
|
|
|
|
label_str += std::to_string(label_list[j]) + (j + 1 < label_list.size() ? ", " : "]");
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
MS_LOG(INFO) << "index[" << i << "], node name[" << cur_cnode_ptr->fullname_with_scope() << "], logic id["
|
|
|
|
|
|
|
|
<< AnfAlgo::GetStreamDistinctionLabel(cur_cnode_ptr.get()) << "], stream id["
|
|
|
|
|
|
|
|
<< AnfAlgo::GetStreamId(cur_cnode_ptr) << "], node info[" << cur_cnode_ptr->DebugString() << "]"
|
|
|
|
|
|
|
|
<< event_str << label_str;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|