|
|
@ -281,24 +281,6 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
|
|
|
|
return node_adjoint;
|
|
|
|
return node_adjoint;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void TensorSetAddress(const ValuePtr &value, std::map<std::string, tensor::TensorPtr> *tuple_tensors) {
|
|
|
|
|
|
|
|
MS_LOG(DEBUG) << "Start set tensor address" << value->ToString() << value->isa<tensor::Tensor>();
|
|
|
|
|
|
|
|
if (value->isa<tensor::Tensor>()) {
|
|
|
|
|
|
|
|
auto tnode = value->cast<tensor::TensorPtr>();
|
|
|
|
|
|
|
|
if (tuple_tensors->find(tnode->id()) != tuple_tensors->end()) {
|
|
|
|
|
|
|
|
MS_LOG(DEBUG) << "Set tensor" << tnode->device_address();
|
|
|
|
|
|
|
|
(*tuple_tensors)[tnode->id()]->set_device_address(tnode->device_address());
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if (value->isa<ValueTuple>()) {
|
|
|
|
|
|
|
|
auto tuple = value->cast<ValueTuplePtr>();
|
|
|
|
|
|
|
|
for (size_t i = 0; i < tuple->size(); i++) {
|
|
|
|
|
|
|
|
MS_LOG(DEBUG) << "Set tuple tensor" << (*tuple)[i]->ToString();
|
|
|
|
|
|
|
|
TensorSetAddress((*tuple)[i], tuple_tensors);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ValuePtr GenNewTensorInner(const ValuePtr &value) {
|
|
|
|
ValuePtr GenNewTensorInner(const ValuePtr &value) {
|
|
|
|
std::vector<ValuePtr> value_list;
|
|
|
|
std::vector<ValuePtr> value_list;
|
|
|
|
if (value->isa<tensor::Tensor>()) {
|
|
|
|
if (value->isa<tensor::Tensor>()) {
|
|
|
@ -328,7 +310,6 @@ ValuePtr GenNewTensor(const FuncGraphManagerPtr &mng, const AnfNodePtr &node, co
|
|
|
|
|
|
|
|
|
|
|
|
void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_morph) {
|
|
|
|
void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_morph) {
|
|
|
|
auto forward = cnode_morph->forward().first;
|
|
|
|
auto forward = cnode_morph->forward().first;
|
|
|
|
auto forward_id = cnode_morph->forward().second;
|
|
|
|
|
|
|
|
if (forward == nullptr) {
|
|
|
|
if (forward == nullptr) {
|
|
|
|
return;
|
|
|
|
return;
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -337,6 +318,7 @@ void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_mor
|
|
|
|
return;
|
|
|
|
return;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
auto fg = GetValueNode<FuncGraphPtr>(input);
|
|
|
|
auto fg = GetValueNode<FuncGraphPtr>(input);
|
|
|
|
|
|
|
|
// {prim::maketuple, forward_output, bprop_graph}
|
|
|
|
auto output = fg->output();
|
|
|
|
auto output = fg->output();
|
|
|
|
if (!output->isa<CNode>()) {
|
|
|
|
if (!output->isa<CNode>()) {
|
|
|
|
return;
|
|
|
|
return;
|
|
|
@ -350,25 +332,22 @@ void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_mor
|
|
|
|
if (!IsValueNode<FuncGraph>(input_fg)) {
|
|
|
|
if (!IsValueNode<FuncGraph>(input_fg)) {
|
|
|
|
return;
|
|
|
|
return;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
std::map<std::string, tensor::TensorPtr> tuple_tensors;
|
|
|
|
// replace forward output with value node
|
|
|
|
auto equivdout = cnode_input->cast<CNodePtr>();
|
|
|
|
auto equivdout = cnode_input->cast<CNodePtr>();
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(equivdout);
|
|
|
|
auto func_graph = GetValueNode<FuncGraphPtr>(input_fg);
|
|
|
|
auto func_graph = GetValueNode<FuncGraphPtr>(input_fg);
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(func_graph);
|
|
|
|
auto manager = Manage({fg, func_graph}, false);
|
|
|
|
auto manager = Manage({fg, func_graph}, false);
|
|
|
|
auto ref_size = manager->node_users()[equivdout].size();
|
|
|
|
|
|
|
|
auto forward_value = forward;
|
|
|
|
|
|
|
|
if (!forward_id.empty() && ref_size > 1) {
|
|
|
|
|
|
|
|
auto inst = pynative::PynativeExecutor::GetInstance();
|
|
|
|
|
|
|
|
inst->SaveOpForwardValue(forward_id, forward_value, &tuple_tensors);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
forward_value = GenNewTensor(manager, equivdout, forward);
|
|
|
|
auto forward_value = GenNewTensor(manager, equivdout, forward);
|
|
|
|
MS_LOG(DEBUG) << "Replace: " << equivdout->ToString() << " with " << forward;
|
|
|
|
MS_LOG(DEBUG) << "Replace: " << equivdout->ToString() << " with " << forward;
|
|
|
|
auto value_node = NewValueNode(forward_value);
|
|
|
|
auto value_node = NewValueNode(forward_value);
|
|
|
|
value_node->set_has_new_value(true);
|
|
|
|
value_node->set_has_new_value(true);
|
|
|
|
manager->Replace(equivdout, value_node);
|
|
|
|
manager->Replace(equivdout, value_node);
|
|
|
|
|
|
|
|
// replace input object with value node
|
|
|
|
auto paras = fg->parameters();
|
|
|
|
auto paras = fg->parameters();
|
|
|
|
auto inputs_value = cnode_morph->inputs_value();
|
|
|
|
auto inputs_value = cnode_morph->inputs_value();
|
|
|
|
if (inputs_value.size() == 0) {
|
|
|
|
if (inputs_value.empty()) {
|
|
|
|
return;
|
|
|
|
return;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if (inputs_value.size() != paras.size()) {
|
|
|
|
if (inputs_value.size() != paras.size()) {
|
|
|
@ -379,10 +358,6 @@ void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_mor
|
|
|
|
auto input_value = inputs_value[i];
|
|
|
|
auto input_value = inputs_value[i];
|
|
|
|
if (para_ref_size > 0 && input_value.first != nullptr) {
|
|
|
|
if (para_ref_size > 0 && input_value.first != nullptr) {
|
|
|
|
MS_LOG(DEBUG) << "Replace: " << paras[i]->ToString() << " with " << input_value.first;
|
|
|
|
MS_LOG(DEBUG) << "Replace: " << paras[i]->ToString() << " with " << input_value.first;
|
|
|
|
auto inst = pynative::PynativeExecutor::GetInstance();
|
|
|
|
|
|
|
|
if (!input_value.second.empty()) {
|
|
|
|
|
|
|
|
inst->SaveOpForwardValue(input_value.second, input_value.first, &tuple_tensors);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
auto input_value_node = NewValueNode(input_value.first);
|
|
|
|
auto input_value_node = NewValueNode(input_value.first);
|
|
|
|
input_value_node->set_has_new_value(true);
|
|
|
|
input_value_node->set_has_new_value(true);
|
|
|
|
manager->Replace(paras[i], input_value_node);
|
|
|
|
manager->Replace(paras[i], input_value_node);
|
|
|
@ -394,30 +369,19 @@ void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_mor
|
|
|
|
res->set_func_graph(fg);
|
|
|
|
res->set_func_graph(fg);
|
|
|
|
PynativeElimOpt(res);
|
|
|
|
PynativeElimOpt(res);
|
|
|
|
auto out = fg->output()->cast<CNodePtr>();
|
|
|
|
auto out = fg->output()->cast<CNodePtr>();
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(out);
|
|
|
|
auto c_input = out->input(1);
|
|
|
|
auto c_input = out->input(1);
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(c_input);
|
|
|
|
if (!c_input->isa<ValueNode>()) {
|
|
|
|
if (!c_input->isa<ValueNode>()) {
|
|
|
|
return;
|
|
|
|
return;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
auto out_node = c_input->cast<ValueNodePtr>();
|
|
|
|
auto out_node = c_input->cast<ValueNodePtr>();
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(out_node);
|
|
|
|
out_node->set_value(GenNewTensor(manager, out_node, out_node->value()));
|
|
|
|
out_node->set_value(GenNewTensor(manager, out_node, out_node->value()));
|
|
|
|
|
|
|
|
// clear resource
|
|
|
|
cnode_morph->clear_inputs_value();
|
|
|
|
cnode_morph->clear_inputs_value();
|
|
|
|
|
|
|
|
|
|
|
|
if (tuple_tensors.size() != 0) {
|
|
|
|
|
|
|
|
MS_LOG(DEBUG) << "Start tuple out" << fg->output()->DebugString(4);
|
|
|
|
|
|
|
|
for (auto &g : manager->func_graphs()) {
|
|
|
|
|
|
|
|
for (auto &node : g->value_nodes()) {
|
|
|
|
|
|
|
|
MS_LOG(DEBUG) << "Set Tensor addr" << node.first->ToString();
|
|
|
|
|
|
|
|
auto vnode = node.first->cast<ValueNodePtr>()->value();
|
|
|
|
|
|
|
|
TensorSetAddress(vnode, &tuple_tensors);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fg->ClearAllManagerInfo();
|
|
|
|
fg->ClearAllManagerInfo();
|
|
|
|
func_graph->ClearAllManagerInfo();
|
|
|
|
func_graph->ClearAllManagerInfo();
|
|
|
|
return;
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
bool DFunctor::IsFreeMorphism(const AnfNodePtr &node) {
|
|
|
|
bool DFunctor::IsFreeMorphism(const AnfNodePtr &node) {
|
|
|
|