@@ -19,6 +19,7 @@
 #include "backend/optimizer/common/helper.h"
 #include "backend/session/anf_runtime_algorithm.h"
 #include "utils/utils.h"
+#include "utils/trace_base.h"
 
 namespace mindspore {
 namespace opt {
@@ -41,7 +42,8 @@ void GetBNOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, std::vect
   auto manager = func_graph->manager();
   MS_EXCEPTION_IF_NULL(manager);
   if (manager->node_users().find(bn) == manager->node_users().end()) {
-    MS_LOG(EXCEPTION) << "The bn node " << bn->DebugString() << " should has some outputs";
+    MS_LOG(EXCEPTION) << "The bn node " << bn->DebugString() << " should has some outputs"
+                      << " trace: " << trace::DumpSourceLines(bn);
   }
   for (const auto &node_index : manager->node_users()[bn]) {
     const AnfNodePtr &output = node_index.first;
@@ -113,7 +115,8 @@ AnfNodePtr FusedBatchNormFusion::CreateBNTrainingReduce(const FuncGraphPtr &func
   // Set input to create node
   auto iter_data_input0 = (*equiv).find(data_input0_var_);
   if (iter_data_input0 == (*equiv).end()) {
-    MS_LOG(EXCEPTION) << "The equiv map is expected to contains the data_input0 var after matched.";
+    MS_LOG(EXCEPTION) << "The equiv map is expected to contains the data_input0 var after matched."
+                      << " trace: " << trace::DumpSourceLines(node);
   }
   std::vector<AnfNodePtr> bn_training_reduce_inputs = {
     NewValueNode(std::make_shared<Primitive>(kBNTrainingReduceOpName)),
@@ -124,13 +127,15 @@ AnfNodePtr FusedBatchNormFusion::CreateBNTrainingReduce(const FuncGraphPtr &func
   // Set abstract
   auto iter_data_input1 = (*equiv).find(data_input1_var_);
   if (iter_data_input1 == (*equiv).end()) {
-    MS_LOG(EXCEPTION) << "The equiv map is expected to contains the data_input1 var after matched.";
+    MS_LOG(EXCEPTION) << "The equiv map is expected to contains the data_input1 var after matched."
+                      << " trace: " << trace::DumpSourceLines(node);
   }
   auto data_input1 = utils::cast<AnfNodePtr>(iter_data_input1->second);
   MS_EXCEPTION_IF_NULL(data_input1);
   auto iter_data_input2 = (*equiv).find(data_input2_var_);
   if (iter_data_input2 == (*equiv).end()) {
-    MS_LOG(EXCEPTION) << "The equiv map is expected to contains the data_input2 var after matched.";
+    MS_LOG(EXCEPTION) << "The equiv map is expected to contains the data_input2 var after matched."
+                      << " trace: " << trace::DumpSourceLines(node);
   }
   auto data_input2 = utils::cast<AnfNodePtr>(iter_data_input2->second);
   MS_EXCEPTION_IF_NULL(data_input2);
@@ -190,17 +195,19 @@ void FusedBatchNormFusion::GetBNTrainingUpdateAbstractList(const EquivPtr &equiv
   MS_EXCEPTION_IF_NULL(bn_abstract_tuple);
   if (bn_abstract_tuple->elements().size() < kBnOutputNum) {
     MS_LOG(EXCEPTION) << "The abstract size of node bn must not be less than " << kBnOutputNum << ", but it is "
-                      << bn_abstract_tuple->elements().size();
+                      << bn_abstract_tuple->elements().size() << " trace: " << trace::DumpSourceLines(bn);
   }
   auto iter_variable_input0 = (*equiv).find(variable_input0_var_);
   if (iter_variable_input0 == (*equiv).end()) {
-    MS_LOG(EXCEPTION) << "The equiv map is expected to contains the variable_input0 var after matched.";
+    MS_LOG(EXCEPTION) << "The equiv map is expected to contains the variable_input0 var after matched."
+                      << " trace: " << trace::DumpSourceLines(bn);
   }
   auto variable_input0 = utils::cast<AnfNodePtr>(iter_variable_input0->second);
   MS_EXCEPTION_IF_NULL(variable_input0);
   auto iter_variable_input1 = (*equiv).find(variable_input1_var_);
   if (iter_variable_input1 == (*equiv).end()) {
-    MS_LOG(EXCEPTION) << "The equiv map is expected to contains the variable_input1 var after matched.";
+    MS_LOG(EXCEPTION) << "The equiv map is expected to contains the variable_input1 var after matched."
+                      << " trace: " << trace::DumpSourceLines(bn);
   }
   auto variable_input1 = utils::cast<AnfNodePtr>(iter_variable_input1->second);
   MS_EXCEPTION_IF_NULL(variable_input1);
@@ -222,7 +229,8 @@ AnfNodePtr FusedBatchNormFusion::CreateBNTrainingUpdate(
   // Set abstract
   auto iter_batch_norm = (*equiv).find(batch_norm_var_);
   if (iter_batch_norm == (*equiv).end()) {
-    MS_LOG(EXCEPTION) << "The equiv map is expected to contains the batch_norm var after matched.";
+    MS_LOG(EXCEPTION) << "The equiv map is expected to contains the batch_norm var after matched."
+                      << " trace: " << trace::DumpSourceLines(node);
   }
   AnfNodePtr bn = utils::cast<AnfNodePtr>(iter_batch_norm->second);
   MS_EXCEPTION_IF_NULL(bn);
@@ -260,12 +268,13 @@ const AnfNodePtr FusedBatchNormFusion::Process(const FuncGraphPtr &func_graph, c
                                  &bn_training_update_outputs);
   if (bn_training_update_outputs.size() < kBNTrainingUpdateOutputNum) {
     MS_LOG(EXCEPTION) << "The output size of node bn must be " << kBNTrainingUpdateOutputNum << ", but it is "
-                      << bn_training_update_outputs.size();
+                      << bn_training_update_outputs.size() << " trace: " << trace::DumpSourceLines(node);
   }
   // Replace old bn outputs with new outputs
   auto iter_batch_norm = (*equiv).find(batch_norm_var_);
   if (iter_batch_norm == (*equiv).end()) {
-    MS_LOG(EXCEPTION) << "The equiv map is expected to contains the batch_norm var after matched.";
+    MS_LOG(EXCEPTION) << "The equiv map is expected to contains the batch_norm var after matched."
+                      << " trace: " << trace::DumpSourceLines(node);
   }
   AnfNodePtr bn = utils::cast<AnfNodePtr>(iter_batch_norm->second);
   std::vector<AnfNodePtr> bn_outputs;
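
For reference, a minimal standalone sketch (not part of the patch) of the pattern every hunk above applies: each MS_LOG(EXCEPTION) message gains a trailing `" trace: " << trace::DumpSourceLines(node)` so the graph-source location of the failing node is printed alongside the error. The helper name CheckNodeHasUsers below is hypothetical; the types and calls (FuncGraphPtr, AnfNodePtr, node_users(), DebugString(), trace::DumpSourceLines) are taken directly from the hunks, and any include beyond "utils/trace_base.h" is an assumption about the surrounding file.

```cpp
#include "utils/trace_base.h"  // provides trace::DumpSourceLines, added by the hunk at -19,6

// Hypothetical helper mirroring GetBNOutput's check: raise an exception with the
// node's source trace appended when the node has no users in the graph.
void CheckNodeHasUsers(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  auto manager = func_graph->manager();
  MS_EXCEPTION_IF_NULL(manager);
  if (manager->node_users().find(node) == manager->node_users().end()) {
    // Same reporting style as the patched call sites: message first, then the trace.
    MS_LOG(EXCEPTION) << "The node " << node->DebugString() << " should have some outputs"
                      << " trace: " << trace::DumpSourceLines(node);
  }
}
```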