|
|
|
@ -1155,6 +1155,9 @@ void DfGraphConvertor::SetOpControlInput(const AnfNodePtr node) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const std::vector<std::string> trans_var_list = {prim::kPrimAssign->name(), string(kNameAssignAdd),
|
|
|
|
|
string(kNameAssignSub)};
|
|
|
|
|
|
|
|
|
|
void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node) {
|
|
|
|
|
OperatorPtr src = Convert(node);
|
|
|
|
|
auto &inputs = node->inputs();
|
|
|
|
@ -1167,6 +1170,26 @@ void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node
|
|
|
|
|
if (IsValueNode<None>(pred)) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
// transform "Const" op to "Variable" op when the next node is "Assign" op.
|
|
|
|
|
std::string c_name = GetCNodeFuncName(node);
|
|
|
|
|
auto pos = std::find(trans_var_list.begin(), trans_var_list.end(), c_name);
|
|
|
|
|
if (!training_ && pos != trans_var_list.end() && pred->isa<Parameter>()) {
|
|
|
|
|
std::string name = std::static_pointer_cast<Parameter>(pred)->name();
|
|
|
|
|
auto op_itor = op_cache_.find(pred.get());
|
|
|
|
|
if (op_itor == op_cache_.end()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Can not find op for node " << pred->ToString() << ".";
|
|
|
|
|
}
|
|
|
|
|
if (op_itor->second != nullptr &&
|
|
|
|
|
(op_itor->second->GetOpType() == "Constant" || op_itor->second->GetOpType() == "Const") &&
|
|
|
|
|
vars_.find(name) != vars_.end()) {
|
|
|
|
|
auto variable = std::make_shared<Variable>(name);
|
|
|
|
|
auto desc = vars_[name]->GetOutputDesc("y");
|
|
|
|
|
(void)variable->update_output_desc_y(desc);
|
|
|
|
|
MS_LOG(DEBUG) << "Trans to variable, var = " << variable->GetName() << ".";
|
|
|
|
|
op_itor->second = variable; // replace parameter with variable
|
|
|
|
|
vars_[name] = variable;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// find in out_hadnle_cache_ first
|
|
|
|
|
auto it = out_handle_cache_.find(pred.get());
|
|
|
|
|
if (it != out_handle_cache_.end()) {
|
|
|
|
|