|
|
|
@ -25,6 +25,7 @@ namespace opt {
|
|
|
|
|
namespace {
|
|
|
|
|
const std::vector<int> kOutputIndex{0, 3, 4, 5};
|
|
|
|
|
constexpr size_t kBatchNormRealOutputNum = 3;
|
|
|
|
|
constexpr size_t kBatchNormRealInputNum = 3;
|
|
|
|
|
|
|
|
|
|
bool CompareTupleGetitem(const AnfNodePtr &n1, const AnfNodePtr &n2) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(n1);
|
|
|
|
@ -56,6 +57,9 @@ bool GetBatchNormOutputs(const FuncGraphPtr &func_graph, const AnfNodePtr &bn, s
|
|
|
|
|
for (const auto &node_index : manager->node_users()[bn]) {
|
|
|
|
|
AnfNodePtr output = node_index.first;
|
|
|
|
|
MS_EXCEPTION_IF_NULL(output);
|
|
|
|
|
if (!IsPrimitiveCNode(output, prim::kPrimTupleGetItem)) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto tuple_getiterm_cnode = output->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(tuple_getiterm_cnode);
|
|
|
|
|
auto index_node = tuple_getiterm_cnode->input(kInputNodeOutputIndexInTupleGetItem);
|
|
|
|
@ -77,7 +81,10 @@ AnfNodePtr CreateBNTrainingReduce(const FuncGraphPtr &func_graph, const AnfNodeP
|
|
|
|
|
MS_EXCEPTION_IF_NULL(bn);
|
|
|
|
|
auto bn_cnode = bn->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(bn_cnode);
|
|
|
|
|
CheckCNodeInputSize(bn_cnode, kBatchNormInputNum + 1);
|
|
|
|
|
if (bn_cnode->inputs().size() < kBatchNormRealInputNum + 1) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "The input size of node " + bn_cnode->DebugString() + " is less than "
|
|
|
|
|
<< kBatchNormRealInputNum + 1;
|
|
|
|
|
}
|
|
|
|
|
std::vector<AnfNodePtr> bn_training_reduce_inputs = {
|
|
|
|
|
NewValueNode(std::make_shared<Primitive>(kBNTrainingReduceOpName)), bn_cnode->input(1)};
|
|
|
|
|
auto bn_training_reduce = func_graph->NewCNode(bn_training_reduce_inputs);
|
|
|
|
@ -100,7 +107,10 @@ AnfNodePtr CreateBNTrainingUpdateV2(const FuncGraphPtr &func_graph, const AnfNod
|
|
|
|
|
MS_EXCEPTION_IF_NULL(bn);
|
|
|
|
|
auto bn_cnode = bn->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(bn_cnode);
|
|
|
|
|
CheckCNodeInputSize(bn_cnode, kBatchNormInputNum + 1);
|
|
|
|
|
if (bn_cnode->inputs().size() < kBatchNormRealInputNum + 1) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "The input size of node " + bn_cnode->DebugString() + " is less than "
|
|
|
|
|
<< kBatchNormRealInputNum + 1;
|
|
|
|
|
}
|
|
|
|
|
if (bn_training_reduce_outputs.size() != kBNTrainingReduceOutputNum) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "The output size of node bn_training_reduce must be " << kBNTrainingReduceOutputNum
|
|
|
|
|
<< ", but it is " << bn_training_reduce_outputs.size();
|
|
|
|
@ -164,7 +174,8 @@ const AnfNodePtr BatchNormBertFission::Process(const FuncGraphPtr &func_graph, c
|
|
|
|
|
(void)manager->Replace(output, bn_training_update_v2_outputs[output_index]);
|
|
|
|
|
output_index++;
|
|
|
|
|
}
|
|
|
|
|
return nullptr;
|
|
|
|
|
// Return the new node for control depends.
|
|
|
|
|
return bn_training_update_v2;
|
|
|
|
|
}
|
|
|
|
|
} // namespace opt
|
|
|
|
|
} // namespace mindspore
|
|
|
|
|