|
|
|
@ -165,6 +165,21 @@ void KernelGraph::SetExecOrderByDefault() {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
CheckLoop();
|
|
|
|
|
// resort start label / end goto
|
|
|
|
|
std::vector<CNodePtr> re_order;
|
|
|
|
|
if (start_label_ != nullptr) {
|
|
|
|
|
re_order.push_back(start_label_);
|
|
|
|
|
}
|
|
|
|
|
for (auto &node : execution_order_) {
|
|
|
|
|
if (node == start_label_ || node == end_goto_) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
re_order.push_back(node);
|
|
|
|
|
}
|
|
|
|
|
if (end_goto_ != nullptr) {
|
|
|
|
|
re_order.push_back(end_goto_);
|
|
|
|
|
}
|
|
|
|
|
execution_order_ = re_order;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void KernelGraph::CheckLoop() {
|
|
|
|
@ -360,7 +375,8 @@ void KernelGraph::FrontBackendlMapAdd(const AnfNodePtr &front_anf, const AnfNode
|
|
|
|
|
void KernelGraph::FrontBackendlMapUpdate(const AnfNodePtr &old_backend_anf, const AnfNodePtr &new_backend_anf) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(old_backend_anf);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(new_backend_anf);
|
|
|
|
|
if (old_backend_anf.get() == new_backend_anf.get()) {
|
|
|
|
|
if (old_backend_anf == new_backend_anf) {
|
|
|
|
|
MS_LOG(INFO) << "old:" << old_backend_anf->DebugString() << ",new:" << new_backend_anf->DebugString();
|
|
|
|
|
MS_LOG(EXCEPTION) << "old can't be same with new";
|
|
|
|
|
}
|
|
|
|
|
if (backend_front_anf_map_.find(old_backend_anf) == backend_front_anf_map_.end()) {
|
|
|
|
@ -569,32 +585,52 @@ void KernelGraph::ReplaceNode(const AnfNodePtr &old_anf_node, AnfNodePtr new_anf
|
|
|
|
|
MS_EXCEPTION_IF_NULL(new_anf_node);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(inputs_);
|
|
|
|
|
auto it = node_output_edges_.find(old_anf_node);
|
|
|
|
|
if (it == node_output_edges_.end()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Can't find anf node in node_output_edges map";
|
|
|
|
|
}
|
|
|
|
|
auto &outputs = it->second;
|
|
|
|
|
for (auto &output_node : outputs) {
|
|
|
|
|
auto output_cnode = output_node.first->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(output_cnode);
|
|
|
|
|
auto &output_node_inputs = output_cnode->inputs();
|
|
|
|
|
for (size_t i = 1; i < output_node_inputs.size(); i++) {
|
|
|
|
|
if (output_node_inputs[i] == old_anf_node) {
|
|
|
|
|
output_cnode->set_input(i, new_anf_node);
|
|
|
|
|
if (it != node_output_edges_.end()) {
|
|
|
|
|
const auto &outputs = it->second;
|
|
|
|
|
for (auto &output_node : outputs) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(output_node.first);
|
|
|
|
|
auto output_cnode = output_node.first->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(output_cnode);
|
|
|
|
|
const auto &output_node_inputs = output_cnode->inputs();
|
|
|
|
|
for (size_t i = 1; i < output_node_inputs.size(); i++) {
|
|
|
|
|
if (output_node_inputs[i] == old_anf_node) {
|
|
|
|
|
output_cnode->set_input(i, new_anf_node);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// update graph inputs
|
|
|
|
|
for (size_t i = 0; i < inputs_->size(); i++) {
|
|
|
|
|
if ((*inputs_)[i] == old_anf_node) {
|
|
|
|
|
(*inputs_)[i] = new_anf_node;
|
|
|
|
|
break;
|
|
|
|
|
// update graph inputs
|
|
|
|
|
for (size_t i = 0; i < inputs_->size(); i++) {
|
|
|
|
|
if ((*inputs_)[i] == old_anf_node) {
|
|
|
|
|
MS_LOG(INFO) << "Replace input of graph:" << graph_id_ << ", old graph input: " << old_anf_node->DebugString()
|
|
|
|
|
<< ",new graph input:" << new_anf_node->DebugString();
|
|
|
|
|
(*inputs_)[i] = new_anf_node;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(INFO) << "Inputs of graph id:" << graph_id();
|
|
|
|
|
for (size_t i = 0; i < inputs().size(); i++) {
|
|
|
|
|
MS_LOG(INFO) << "[" << i << "]:" << inputs()[i]->DebugString();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// update front to backend map
|
|
|
|
|
FrontBackendlMapUpdate(old_anf_node, new_anf_node);
|
|
|
|
|
// update output depend relations
|
|
|
|
|
node_output_edges_[new_anf_node] = it->second;
|
|
|
|
|
(void)node_output_edges_.erase(old_anf_node);
|
|
|
|
|
}
|
|
|
|
|
// update graph inputs in child graph
|
|
|
|
|
auto it_real_inputs = real_inputs_.find(old_anf_node);
|
|
|
|
|
if (it_real_inputs != real_inputs_.end()) {
|
|
|
|
|
// insert new parameter to map
|
|
|
|
|
auto iter = real_inputs_.find(new_anf_node);
|
|
|
|
|
if (iter != real_inputs_.end()) {
|
|
|
|
|
MS_LOG(WARNING) << new_anf_node->DebugString() << " already exist in real inputs, will be rewrited.";
|
|
|
|
|
iter->second = it_real_inputs->second;
|
|
|
|
|
} else {
|
|
|
|
|
real_inputs_[new_anf_node] = it_real_inputs->second;
|
|
|
|
|
}
|
|
|
|
|
// erase old parameter in map
|
|
|
|
|
real_inputs_.erase(old_anf_node);
|
|
|
|
|
}
|
|
|
|
|
// update front to backend map
|
|
|
|
|
FrontBackendlMapUpdate(old_anf_node, new_anf_node);
|
|
|
|
|
// update output depend relations
|
|
|
|
|
node_output_edges_[new_anf_node] = it->second;
|
|
|
|
|
(void)node_output_edges_.erase(old_anf_node);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void KernelGraph::UpdateExecuteKernelStreamLabel() {
|
|
|
|
@ -603,29 +639,6 @@ void KernelGraph::UpdateExecuteKernelStreamLabel() {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void KernelGraph::UpdateChildGraphOrder() {
|
|
|
|
|
MS_LOG(INFO) << "graph id:" << graph_id_;
|
|
|
|
|
auto call_nodes = FindNodeByPrimitive(std::make_shared<Primitive>(prim::kPrimCall->name()));
|
|
|
|
|
for (auto &old_child_graph : child_graph_order_) {
|
|
|
|
|
old_child_graph->set_parent_graph(nullptr);
|
|
|
|
|
}
|
|
|
|
|
child_graph_order_.clear();
|
|
|
|
|
for (auto &call_node : call_nodes) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(call_node);
|
|
|
|
|
auto call_child_graphs = AnfAlgo ::GetCallNodeKernelGraph(call_node->cast<CNodePtr>());
|
|
|
|
|
for (const auto &child_graph : call_child_graphs) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(child_graph);
|
|
|
|
|
if (child_graph != parent_graph()) {
|
|
|
|
|
child_graph->set_parent_graph(shared_from_this()->cast<std::shared_ptr<KernelGraph>>());
|
|
|
|
|
child_graph_order_.push_back(child_graph);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
for (size_t i = 0; i < child_graph_order_.size(); i++) {
|
|
|
|
|
MS_LOG(INFO) << "child graph[" << i << "][id:" << child_graph_order_[i]->graph_id() << "]";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<std::shared_ptr<KernelGraph>> KernelGraph::GetLeafGraphOrder() {
|
|
|
|
|
std::vector<std::shared_ptr<KernelGraph>> leaf_graph_order;
|
|
|
|
|
if (IsLeafGraph()) {
|
|
|
|
@ -643,9 +656,8 @@ std::vector<std::shared_ptr<KernelGraph>> KernelGraph::GetLeafGraphOrder() {
|
|
|
|
|
bool KernelGraph::IsLeafGraph() const { return child_graph_order_.empty(); }
|
|
|
|
|
|
|
|
|
|
std::vector<CNodePtr> KernelGraph::FindNodeByPrimitive(const PrimitivePtr &primitive) const {
|
|
|
|
|
auto anf_list = TopoSort(get_return());
|
|
|
|
|
std::vector<CNodePtr> result;
|
|
|
|
|
for (const auto &anf : anf_list) {
|
|
|
|
|
for (const auto &anf : execution_order_) {
|
|
|
|
|
if (AnfAlgo::CheckPrimitiveType(anf, primitive) && AnfAlgo::GetGraphId(anf.get()) == graph_id_) {
|
|
|
|
|
result.push_back(anf->cast<CNodePtr>());
|
|
|
|
|
}
|
|
|
|
@ -653,14 +665,6 @@ std::vector<CNodePtr> KernelGraph::FindNodeByPrimitive(const PrimitivePtr &primi
|
|
|
|
|
return result;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::set<AnfNodePtr> KernelGraph::GetRealInput(const AnfNodePtr ¶meter) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(parameter);
|
|
|
|
|
if (real_inputs_.find(parameter) == real_inputs_.end()) {
|
|
|
|
|
return {};
|
|
|
|
|
}
|
|
|
|
|
return real_inputs_[parameter];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void KernelGraph::SetRealInput(const AnfNodePtr ¶meter, const AnfNodePtr &arg) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(parameter);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(arg);
|
|
|
|
@ -674,37 +678,41 @@ void KernelGraph::SetRealInput(const AnfNodePtr ¶meter, const AnfNodePtr &ar
|
|
|
|
|
(void)args.insert(arg);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::set<AnfNodePtr> KernelGraph::GetRealInput(const AnfNodePtr ¶meter) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(parameter);
|
|
|
|
|
auto iter = real_inputs_.find(parameter);
|
|
|
|
|
if (iter != real_inputs_.end()) {
|
|
|
|
|
return iter->second;
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(EXCEPTION) << parameter->DebugString() << " not found.";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void KernelGraph::UpdateCallRealInput() {
|
|
|
|
|
MS_LOG(INFO) << "Update graph id: " << graph_id_;
|
|
|
|
|
for (auto &it : real_inputs_) {
|
|
|
|
|
auto ¶meter = it.first;
|
|
|
|
|
MS_EXCEPTION_IF_NULL(parameter);
|
|
|
|
|
auto &real_inputs = it.second;
|
|
|
|
|
std::set<AnfNodePtr> new_real_inputs;
|
|
|
|
|
std::vector<AnfNodePtr> new_real_inputs;
|
|
|
|
|
std::set<AnfNodePtr> erase_real_inputs;
|
|
|
|
|
for (auto &real_input : real_inputs) {
|
|
|
|
|
// if real input is a call node ,find the child graph output act as the new real input
|
|
|
|
|
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(real_input, 0);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(item_with_index.first);
|
|
|
|
|
if (AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimCall)) {
|
|
|
|
|
MS_LOG(INFO) << "paramter: " << parameter->DebugString()
|
|
|
|
|
<< " erase real input:" << item_with_index.first->DebugString();
|
|
|
|
|
(void)erase_real_inputs.insert(item_with_index.first);
|
|
|
|
|
auto call_node_outputs = GetCallRealOutputs(item_with_index.first);
|
|
|
|
|
for (auto &call_node_output : call_node_outputs) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(call_node_output);
|
|
|
|
|
MS_LOG(INFO) << "paramter: " << parameter->DebugString()
|
|
|
|
|
<< " insert real input:" << call_node_output->DebugString();
|
|
|
|
|
(void)new_real_inputs.insert(call_node_output);
|
|
|
|
|
}
|
|
|
|
|
new_real_inputs = GetCallRealOutputs(item_with_index.first);
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
for (auto &erase_node : erase_real_inputs) {
|
|
|
|
|
(void)real_inputs.erase(erase_node);
|
|
|
|
|
}
|
|
|
|
|
for (auto &new_real_input : new_real_inputs) {
|
|
|
|
|
(void)real_inputs.insert(new_real_input);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
for (auto &erase_node : erase_real_inputs) {
|
|
|
|
|
MS_LOG(INFO) << "paramter: " << parameter->DebugString() << " erase real input:" << erase_node->DebugString();
|
|
|
|
|
(void)real_inputs.erase(erase_node);
|
|
|
|
|
}
|
|
|
|
|
for (auto &new_real_input : new_real_inputs) {
|
|
|
|
|
MS_LOG(INFO) << "paramter: " << parameter->DebugString()
|
|
|
|
|
<< " insert real input:" << new_real_input->DebugString();
|
|
|
|
|
(void)real_inputs.insert(new_real_input);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|