|
|
|
@ -53,6 +53,37 @@ using Constant = ge::op::Constant;
|
|
|
|
|
using Assign = ge::op::Assign;
|
|
|
|
|
using Data = ge::op::Data;
|
|
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
|
std::vector<AnfNodePtr> GetOrderedCNodes(const FuncGraphPtr fg) {
|
|
|
|
|
auto BelongSameGraph = std::bind(IncludeBelongGraph, fg, std::placeholders::_1);
|
|
|
|
|
auto succ_include_fv = [&fg](const AnfNodePtr &node) -> std::vector<AnfNodePtr> {
|
|
|
|
|
std::vector<AnfNodePtr> vecs;
|
|
|
|
|
if (node == nullptr) {
|
|
|
|
|
return vecs;
|
|
|
|
|
}
|
|
|
|
|
if (node->isa<CNode>()) {
|
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
|
auto &inputs = cnode->inputs();
|
|
|
|
|
// Check if free variables used.
|
|
|
|
|
for (const auto &input : inputs) {
|
|
|
|
|
auto input_fg = GetValueNode<FuncGraphPtr>(input);
|
|
|
|
|
if (input_fg) {
|
|
|
|
|
for (auto &fv : input_fg->free_variables_nodes()) {
|
|
|
|
|
if (fv->func_graph() == fg && fg->nodes().contains(fv)) {
|
|
|
|
|
vecs.push_back(fv);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
(void)vecs.insert(vecs.end(), inputs.begin(), inputs.end());
|
|
|
|
|
}
|
|
|
|
|
return vecs;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
return TopoSort(fg->get_return(), succ_include_fv, BelongSameGraph);
|
|
|
|
|
}
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
// ---------------implement of DfGraphConvertor-------------
|
|
|
|
|
PrimType GetCNodeFuncType(const CNodePtr cnode) {
|
|
|
|
|
if (cnode->inputs().empty()) {
|
|
|
|
@ -214,7 +245,7 @@ void DfGraphConvertor::DrawParamInitSubGraph(const std::string &name, const AnfN
|
|
|
|
|
|
|
|
|
|
void DfGraphConvertor::SetupParamInitSubGraph(const TensorOrderMap &tensors, std::vector<ge::Operator> *init_input) {
|
|
|
|
|
DfGraphPtr init_graph = std::make_shared<DfGraph>("init");
|
|
|
|
|
std::vector<AnfNodePtr> nodes = TopoSort(anf_graph_->get_return());
|
|
|
|
|
std::vector<AnfNodePtr> nodes = GetOrderedCNodes(anf_graph_);
|
|
|
|
|
|
|
|
|
|
for (auto &it : nodes) {
|
|
|
|
|
if (it->isa<ValueNode>()) {
|
|
|
|
@ -549,7 +580,7 @@ DfGraphConvertor &DfGraphConvertor::ConvertAllNode() {
|
|
|
|
|
|
|
|
|
|
// Convert all anf node to Operator
|
|
|
|
|
MS_LOG(DEBUG) << "convert all node";
|
|
|
|
|
std::vector<AnfNodePtr> nodes = TopoSort(anf_graph_->get_return());
|
|
|
|
|
std::vector<AnfNodePtr> nodes = GetOrderedCNodes(anf_graph_);
|
|
|
|
|
for (auto &it : nodes) {
|
|
|
|
|
(void)Convert(it);
|
|
|
|
|
if (this->error_ != 0) {
|
|
|
|
@ -811,7 +842,7 @@ DfGraphConvertor &DfGraphConvertor::BuildGraph() {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Case node set input.
|
|
|
|
|
std::vector<AnfNodePtr> nodes = ::mindspore::TopoSort(anf_graph_->get_return());
|
|
|
|
|
std::vector<AnfNodePtr> nodes = GetOrderedCNodes(anf_graph_);
|
|
|
|
|
for (auto &it : nodes) {
|
|
|
|
|
if (it->isa<CNode>() && IsCaseNode(it->cast<CNodePtr>())) {
|
|
|
|
|
auto node = it->cast<CNodePtr>();
|
|
|
|
@ -825,7 +856,7 @@ DfGraphConvertor &DfGraphConvertor::BuildGraph() {
|
|
|
|
|
|
|
|
|
|
// set up dependencies
|
|
|
|
|
MS_LOG(DEBUG) << "set up dependencies";
|
|
|
|
|
nodes = ::mindspore::TopoSort(anf_graph_->get_return());
|
|
|
|
|
nodes = GetOrderedCNodes(anf_graph_);
|
|
|
|
|
for (auto &it : nodes) {
|
|
|
|
|
SetNodeInput(it);
|
|
|
|
|
SetOpControlInput(it);
|
|
|
|
@ -1195,37 +1226,23 @@ void DfGraphConvertor::SetTupleOpInput(const OpAdapterPtr &adpt, const CNodePtr
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(WARNING) << "This anf node is not supported as a tuple item : " << node->ToString();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node) {
|
|
|
|
|
OperatorPtr src = Convert(node);
|
|
|
|
|
int case_flag = 0;
|
|
|
|
|
auto &inputs = node->inputs();
|
|
|
|
|
size_t input_size = inputs.size();
|
|
|
|
|
if (case_input_handle_cache_.find(node.get()) != case_input_handle_cache_.end()) {
|
|
|
|
|
case_flag = 1;
|
|
|
|
|
input_size = case_input_handle_cache_[node.get()]->size() + 1;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (size_t i = 1; i < input_size; i++) {
|
|
|
|
|
AnfNodePtr pred = nullptr;
|
|
|
|
|
if (case_flag != 0) {
|
|
|
|
|
pred = case_input_handle_cache_[node.get()]->at(i - 1);
|
|
|
|
|
} else {
|
|
|
|
|
pred = inputs[i];
|
|
|
|
|
AnfNodePtr DfGraphConvertor::GetRealInputNode(const CNodePtr &node, const AnfNodePtr &input) {
|
|
|
|
|
if (input == nullptr || node == nullptr) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AnfNodePtr pred = input;
|
|
|
|
|
while (pred->isa<CNode>() && GetCNodeTargetFuncName(pred->cast<CNodePtr>()) == prim::kPrimDepend->name()) {
|
|
|
|
|
pred = pred->cast<CNodePtr>()->input(1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// skip input of UMonad, IOMonad
|
|
|
|
|
if (IsValueNode<UMonad>(pred) || IsValueNode<IOMonad>(pred)) {
|
|
|
|
|
continue;
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// skip input of the None, Load, UpdateState
|
|
|
|
|
// skip input of the None, UpdateState
|
|
|
|
|
if (IsValueNode<None>(pred) || IsPrimitiveCNode(pred, prim::kPrimUpdateState)) {
|
|
|
|
|
continue;
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (IsPrimitiveCNode(pred, prim::kPrimLoad)) {
|
|
|
|
@ -1252,6 +1269,31 @@ void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node
|
|
|
|
|
vars_[name] = variable;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return pred;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node) {
|
|
|
|
|
OperatorPtr src = Convert(node);
|
|
|
|
|
int case_flag = 0;
|
|
|
|
|
auto &inputs = node->inputs();
|
|
|
|
|
size_t input_size = inputs.size();
|
|
|
|
|
if (case_input_handle_cache_.find(node.get()) != case_input_handle_cache_.end()) {
|
|
|
|
|
case_flag = 1;
|
|
|
|
|
input_size = case_input_handle_cache_[node.get()]->size() + 1;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (size_t i = 1; i < input_size; i++) {
|
|
|
|
|
AnfNodePtr pred = nullptr;
|
|
|
|
|
if (case_flag != 0) {
|
|
|
|
|
pred = case_input_handle_cache_[node.get()]->at(i - 1);
|
|
|
|
|
} else {
|
|
|
|
|
pred = inputs[i];
|
|
|
|
|
}
|
|
|
|
|
pred = GetRealInputNode(node, pred);
|
|
|
|
|
if (pred == nullptr) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int index = SizeToInt(i);
|
|
|
|
|
// find in out_hadnle_cache_ first
|
|
|
|
|
auto it = out_handle_cache_.find(pred.get());
|
|
|
|
|