|
|
|
@ -28,14 +28,14 @@
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace opt {
|
|
|
|
|
namespace {
|
|
|
|
|
void CreateOutputsOfBNTrainingReduce(const FuncGraphPtr &graph, const CNodePtr &bn_cnode,
|
|
|
|
|
bool CreateOutputsOfBNTrainingReduce(const FuncGraphPtr &graph, const CNodePtr &bn_cnode,
|
|
|
|
|
std::vector<AnfNodePtr> *bn_training_reduce_outputs) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(bn_cnode);
|
|
|
|
|
if (bn_cnode->inputs().size() != kBnInputNum) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "BN node has wrong input size";
|
|
|
|
|
MS_LOG(INFO) << "FusedbatchNorm's input size less than " << kBnInputNum << ". " << bn_cnode->DebugString();
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
// All the inputs of BNTrainingReduce are from the inputs of BN
|
|
|
|
|
std::vector<AnfNodePtr> bn_training_reduce_inputs = {
|
|
|
|
|
NewValueNode(std::make_shared<Primitive>(kBNTrainingReduceOpName))};
|
|
|
|
|
bn_training_reduce_inputs.push_back(bn_cnode->input(1));
|
|
|
|
@ -45,8 +45,9 @@ void CreateOutputsOfBNTrainingReduce(const FuncGraphPtr &graph, const CNodePtr &
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_info);
|
|
|
|
|
bn_training_reduce->set_kernel_info(kernel_info);
|
|
|
|
|
std::vector<size_t> bn_shape_i0 = AnfAlgo::GetPrevNodeOutputInferShape(bn_cnode, 0);
|
|
|
|
|
if (bn_shape_i0.size() != kShape4dDims) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Get shape of FusedBatchNorm fail";
|
|
|
|
|
if (bn_shape_i0.size() < kShape2dDims) {
|
|
|
|
|
MS_LOG(INFO) << "The FusedBatchNorm's first input's shape dims less than " << kShape2dDims;
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
std::vector<size_t> bn_training_reduce_shape = {bn_shape_i0[1]};
|
|
|
|
|
auto types = {kNumberTypeFloat32, kNumberTypeFloat32};
|
|
|
|
@ -56,6 +57,7 @@ void CreateOutputsOfBNTrainingReduce(const FuncGraphPtr &graph, const CNodePtr &
|
|
|
|
|
AnfAlgo::CopyNodeAttrs(bn_cnode, bn_training_reduce);
|
|
|
|
|
|
|
|
|
|
CreateMultipleOutputsOfAnfNode(graph, bn_training_reduce, kBNTrainingReduceOutputNum, bn_training_reduce_outputs);
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AnfNodePtr CreateOutputsOfBNTrainingUpdate(const FuncGraphPtr &graph, const CNodePtr &bn_cnode,
|
|
|
|
@ -99,11 +101,15 @@ AnfNodePtr SplitFusedBatchNormForTBE(const FuncGraphPtr &func_graph, const AnfNo
|
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
if (cnode->inputs().size() < kBnInputNum) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "op[FusedBatchNorm] has less than " << kBnInputNum << " inputs.";
|
|
|
|
|
MS_LOG(INFO) << "op[FusedBatchNorm] has less than " << kBnInputNum << " inputs.";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
// Create BNTrainingReduce node and get outputs of BNTrainingReduce
|
|
|
|
|
std::vector<AnfNodePtr> bn_training_reduce_outputs;
|
|
|
|
|
CreateOutputsOfBNTrainingReduce(func_graph, cnode, &bn_training_reduce_outputs);
|
|
|
|
|
if (!CreateOutputsOfBNTrainingReduce(func_graph, cnode, &bn_training_reduce_outputs)) {
|
|
|
|
|
MS_LOG(WARNING) << "Create BNTrainingReduce fail, quit split";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
if (bn_training_reduce_outputs.size() != kBN1OutputNum) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "make outputs of op BNTrainingReduce fail";
|
|
|
|
|
}
|
|
|
|
|