|
|
|
@ -53,6 +53,37 @@ using Constant = ge::op::Constant;
|
|
|
|
|
using Assign = ge::op::Assign;
|
|
|
|
|
using Data = ge::op::Data;
|
|
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
|
std::vector<AnfNodePtr> GetOrderedCNodes(const FuncGraphPtr fg) {
|
|
|
|
|
auto BelongSameGraph = std::bind(IncludeBelongGraph, fg, std::placeholders::_1);
|
|
|
|
|
auto succ_include_fv = [&fg](const AnfNodePtr &node) -> std::vector<AnfNodePtr> {
|
|
|
|
|
std::vector<AnfNodePtr> vecs;
|
|
|
|
|
if (node == nullptr) {
|
|
|
|
|
return vecs;
|
|
|
|
|
}
|
|
|
|
|
if (node->isa<CNode>()) {
|
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
|
auto &inputs = cnode->inputs();
|
|
|
|
|
// Check if free variables used.
|
|
|
|
|
for (const auto &input : inputs) {
|
|
|
|
|
auto input_fg = GetValueNode<FuncGraphPtr>(input);
|
|
|
|
|
if (input_fg) {
|
|
|
|
|
for (auto &fv : input_fg->free_variables_nodes()) {
|
|
|
|
|
if (fv->func_graph() == fg && fg->nodes().contains(fv)) {
|
|
|
|
|
vecs.push_back(fv);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
(void)vecs.insert(vecs.end(), inputs.begin(), inputs.end());
|
|
|
|
|
}
|
|
|
|
|
return vecs;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
return TopoSort(fg->get_return(), succ_include_fv, BelongSameGraph);
|
|
|
|
|
}
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
// ---------------implement of DfGraphConvertor-------------
|
|
|
|
|
PrimType GetCNodeFuncType(const CNodePtr cnode) {
|
|
|
|
|
if (cnode->inputs().empty()) {
|
|
|
|
@ -214,7 +245,7 @@ void DfGraphConvertor::DrawParamInitSubGraph(const std::string &name, const AnfN
|
|
|
|
|
|
|
|
|
|
void DfGraphConvertor::SetupParamInitSubGraph(const TensorOrderMap &tensors, std::vector<ge::Operator> *init_input) {
|
|
|
|
|
DfGraphPtr init_graph = std::make_shared<DfGraph>("init");
|
|
|
|
|
std::vector<AnfNodePtr> nodes = TopoSort(anf_graph_->get_return());
|
|
|
|
|
std::vector<AnfNodePtr> nodes = GetOrderedCNodes(anf_graph_);
|
|
|
|
|
|
|
|
|
|
for (auto &it : nodes) {
|
|
|
|
|
if (it->isa<ValueNode>()) {
|
|
|
|
@ -549,7 +580,7 @@ DfGraphConvertor &DfGraphConvertor::ConvertAllNode() {
|
|
|
|
|
|
|
|
|
|
// Convert all anf node to Operator
|
|
|
|
|
MS_LOG(DEBUG) << "convert all node";
|
|
|
|
|
std::vector<AnfNodePtr> nodes = TopoSort(anf_graph_->get_return());
|
|
|
|
|
std::vector<AnfNodePtr> nodes = GetOrderedCNodes(anf_graph_);
|
|
|
|
|
for (auto &it : nodes) {
|
|
|
|
|
(void)Convert(it);
|
|
|
|
|
if (this->error_ != 0) {
|
|
|
|
@ -811,7 +842,7 @@ DfGraphConvertor &DfGraphConvertor::BuildGraph() {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Case node set input.
|
|
|
|
|
std::vector<AnfNodePtr> nodes = ::mindspore::TopoSort(anf_graph_->get_return());
|
|
|
|
|
std::vector<AnfNodePtr> nodes = GetOrderedCNodes(anf_graph_);
|
|
|
|
|
for (auto &it : nodes) {
|
|
|
|
|
if (it->isa<CNode>() && IsCaseNode(it->cast<CNodePtr>())) {
|
|
|
|
|
auto node = it->cast<CNodePtr>();
|
|
|
|
@ -825,7 +856,7 @@ DfGraphConvertor &DfGraphConvertor::BuildGraph() {
|
|
|
|
|
|
|
|
|
|
// set up dependencies
|
|
|
|
|
MS_LOG(DEBUG) << "set up dependencies";
|
|
|
|
|
nodes = ::mindspore::TopoSort(anf_graph_->get_return());
|
|
|
|
|
nodes = GetOrderedCNodes(anf_graph_);
|
|
|
|
|
for (auto &it : nodes) {
|
|
|
|
|
SetNodeInput(it);
|
|
|
|
|
SetOpControlInput(it);
|
|
|
|
@ -1195,6 +1226,51 @@ void DfGraphConvertor::SetTupleOpInput(const OpAdapterPtr &adpt, const CNodePtr
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(WARNING) << "This anf node is not supported as a tuple item : " << node->ToString();
|
|
|
|
|
}
|
|
|
|
|
AnfNodePtr DfGraphConvertor::GetRealInputNode(const CNodePtr &node, const AnfNodePtr &input) {
|
|
|
|
|
if (input == nullptr || node == nullptr) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
AnfNodePtr pred = input;
|
|
|
|
|
while (pred->isa<CNode>() && GetCNodeTargetFuncName(pred->cast<CNodePtr>()) == prim::kPrimDepend->name()) {
|
|
|
|
|
pred = pred->cast<CNodePtr>()->input(1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// skip input of UMonad, IOMonad
|
|
|
|
|
if (IsValueNode<UMonad>(pred) || IsValueNode<IOMonad>(pred)) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// skip input of the None, UpdateState
|
|
|
|
|
if (IsValueNode<None>(pred) || IsPrimitiveCNode(pred, prim::kPrimUpdateState)) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (IsPrimitiveCNode(pred, prim::kPrimLoad)) {
|
|
|
|
|
pred = ParseLoadInput(pred->cast<CNodePtr>());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// transform "Const" op to "Variable" op when the next node is "Assign" op.
|
|
|
|
|
std::string c_name = GetCNodeTargetFuncName(node);
|
|
|
|
|
auto pos = std::find(trans_var_list.begin(), trans_var_list.end(), c_name);
|
|
|
|
|
if (!training_ && pos != trans_var_list.end() && pred->isa<Parameter>()) {
|
|
|
|
|
std::string name = std::static_pointer_cast<Parameter>(pred)->name();
|
|
|
|
|
auto op_itor = op_cache_.find(pred.get());
|
|
|
|
|
if (op_itor == op_cache_.end()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Can not find op for node " << pred->ToString() << ".";
|
|
|
|
|
}
|
|
|
|
|
if (op_itor->second != nullptr &&
|
|
|
|
|
(op_itor->second->GetOpType() == "Constant" || op_itor->second->GetOpType() == "Const") &&
|
|
|
|
|
vars_.find(name) != vars_.end()) {
|
|
|
|
|
auto variable = std::make_shared<Variable>(name);
|
|
|
|
|
auto desc = vars_[name]->GetOutputDesc("y");
|
|
|
|
|
(void)variable->update_output_desc_y(desc);
|
|
|
|
|
MS_LOG(DEBUG) << "Trans to variable, var = " << variable->GetName() << ".";
|
|
|
|
|
op_itor->second = variable; // replace parameter with variable
|
|
|
|
|
vars_[name] = variable;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return pred;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node) {
|
|
|
|
|
OperatorPtr src = Convert(node);
|
|
|
|
@ -1213,45 +1289,11 @@ void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node
|
|
|
|
|
} else {
|
|
|
|
|
pred = inputs[i];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
while (pred->isa<CNode>() && GetCNodeTargetFuncName(pred->cast<CNodePtr>()) == prim::kPrimDepend->name()) {
|
|
|
|
|
pred = pred->cast<CNodePtr>()->input(1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// skip input of UMonad, IOMonad
|
|
|
|
|
if (IsValueNode<UMonad>(pred) || IsValueNode<IOMonad>(pred)) {
|
|
|
|
|
pred = GetRealInputNode(node, pred);
|
|
|
|
|
if (pred == nullptr) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// skip input of the None, Load, UpdateState
|
|
|
|
|
if (IsValueNode<None>(pred) || IsPrimitiveCNode(pred, prim::kPrimUpdateState)) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (IsPrimitiveCNode(pred, prim::kPrimLoad)) {
|
|
|
|
|
pred = ParseLoadInput(pred->cast<CNodePtr>());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// transform "Const" op to "Variable" op when the next node is "Assign" op.
|
|
|
|
|
std::string c_name = GetCNodeTargetFuncName(node);
|
|
|
|
|
auto pos = std::find(trans_var_list.begin(), trans_var_list.end(), c_name);
|
|
|
|
|
if (!training_ && pos != trans_var_list.end() && pred->isa<Parameter>()) {
|
|
|
|
|
std::string name = std::static_pointer_cast<Parameter>(pred)->name();
|
|
|
|
|
auto op_itor = op_cache_.find(pred.get());
|
|
|
|
|
if (op_itor == op_cache_.end()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Can not find op for node " << pred->ToString() << ".";
|
|
|
|
|
}
|
|
|
|
|
if (op_itor->second != nullptr &&
|
|
|
|
|
(op_itor->second->GetOpType() == "Constant" || op_itor->second->GetOpType() == "Const") &&
|
|
|
|
|
vars_.find(name) != vars_.end()) {
|
|
|
|
|
auto variable = std::make_shared<Variable>(name);
|
|
|
|
|
auto desc = vars_[name]->GetOutputDesc("y");
|
|
|
|
|
(void)variable->update_output_desc_y(desc);
|
|
|
|
|
MS_LOG(DEBUG) << "Trans to variable, var = " << variable->GetName() << ".";
|
|
|
|
|
op_itor->second = variable; // replace parameter with variable
|
|
|
|
|
vars_[name] = variable;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
int index = SizeToInt(i);
|
|
|
|
|
// find in out_hadnle_cache_ first
|
|
|
|
|
auto it = out_handle_cache_.find(pred.get());
|
|
|
|
|