@@ -25,6 +25,7 @@
 #include "ir/func_graph_cloner.h"
 #include "ir/manager.h"
 #include "pipeline/jit/resource.h"
+#include "pipeline/pynative/pynative_execute.h"
 #include "frontend/optimizer/ad/adjoint.h"
 #include "frontend/operator/ops.h"
 #include "utils/symbolic.h"
@@ -218,7 +219,8 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
   TraceManager::DebugTrace(std::make_shared<TraceGradFpropApp>(cnode_morph->debug_info()));
   auto k_app = k_graph_->NewCNode(inputs);
   TraceManager::EndTrace();
-  ReplaceEquivdout(k_app, cnode_morph->forward());
+  ReplaceEquivdout(k_app, cnode_morph);
+  cnode_morph->set_forward(nullptr, "");
   for (size_t i = 0; i < param_adjoints.size(); ++i) {
     param_adjoints[i]->RegisterKUser(k_app, i);
   }
@@ -240,7 +242,9 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
   return node_adjoint;
 }
 
-void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const ValuePtr &forward) {
+void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_morph) {
+  auto forward = cnode_morph->forward().first;
+  auto forward_id = cnode_morph->forward().second;
   if (forward == nullptr) {
     return;
   }
@@ -265,10 +269,44 @@ void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const ValuePtr &forward)
   auto equivdout = cnode_input->cast<CNodePtr>();
   auto func_graph = GetValueNode<FuncGraphPtr>(input_fg);
   auto manager = Manage({fg, func_graph}, false);
+  auto ref_size = manager->node_users()[equivdout].size();
+  auto forward_value = forward;
+  if (!forward_id.empty() && ref_size > 1) {
+    auto inst = pynative::PynativeExecutor::GetInstance();
+    inst->SaveOpForwardValue(forward_id, forward_value);
+  }
+  if (ref_size < 2) {
+    auto tensor = forward->cast<tensor::TensorPtr>();
+    if (tensor != nullptr) {
+      auto new_tensor = std::make_shared<tensor::Tensor>(tensor->data_type(), tensor->shape());
+      forward_value = new_tensor;
+    }
+  }
   MS_LOG(DEBUG) << "Replace: " << equivdout->ToString() << " with " << forward;
-  auto value_node = NewValueNode(forward);
+  auto value_node = NewValueNode(forward_value);
   value_node->set_has_new_value(true);
   manager->Replace(equivdout, value_node);
+  auto paras = fg->parameters();
+  auto inputs_value = cnode_morph->inputs_value();
+  if (inputs_value.size() == 0) {
+    return;
+  }
+  if (inputs_value.size() != paras.size()) {
+    MS_LOG(EXCEPTION) << "Parameter size:" << paras.size() << " is not equal to inputs size:" << inputs_value.size();
+  }
+  for (size_t i = 0; i < paras.size(); i++) {
+    auto para_ref_size = manager->node_users()[paras[i]].size();
+    auto input_value = inputs_value[i];
+    if (para_ref_size > 0 && input_value.first != nullptr) {
+      MS_LOG(DEBUG) << "Replace: " << paras[i]->ToString() << " with " << input_value.first;
+      auto inst = pynative::PynativeExecutor::GetInstance();
+      inst->SaveOpForwardValue(input_value.second, input_value.first);
+      auto input_value_node = NewValueNode(input_value.first);
+      manager->Replace(paras[i], input_value_node);
+    }
+  }
+  cnode_morph->clear_inputs_value();
+  return;
 }
 
 bool DFunctor::IsFreeMorphism(const AnfNodePtr &node) {