From e4e5c1231c485c7eb5906c9111b9e0d50d051afa Mon Sep 17 00:00:00 2001
From: wangzhe
Date: Fri, 22 Jan 2021 15:41:30 +0800
Subject: [PATCH] modify layer_norm_fusion

---
 mindspore/lite/src/ops/layer_norm.cc          |   2 +-
 .../lite/src/runtime/agent/npu/CMakeLists.txt |   1 +
 .../kernel/arm/fp32/layer_norm_fp32.cc        |   3 -
 .../kernel/arm/int8/layer_norm_int8.cc        |   3 -
 .../lite/tools/converter/anf_transform.cc     |   6 +-
 .../optimizer/fusion/layer_norm_fusion.cc     | 279 +++++++-----------
 .../optimizer/fusion/layer_norm_fusion.h      |  19 +-
 7 files changed, 128 insertions(+), 185 deletions(-)

diff --git a/mindspore/lite/src/ops/layer_norm.cc b/mindspore/lite/src/ops/layer_norm.cc
index b76fea6b0b..1a531d73df 100644
--- a/mindspore/lite/src/ops/layer_norm.cc
+++ b/mindspore/lite/src/ops/layer_norm.cc
@@ -81,7 +81,7 @@ int LayerNorm::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffe
     return RET_ERROR;
   }
 
-  auto val_offset = schema::CreateLayerNorm(*fbb, attr->epsilon(), attr->begin_norm_axis(), attr->begin_params_axis());
+  auto val_offset = schema::CreateLayerNorm(*fbb, attr->begin_norm_axis(), attr->begin_params_axis(), attr->epsilon());
   auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_LayerNorm, val_offset.o);
   fbb->Finish(prim_offset);
   return RET_OK;
diff --git a/mindspore/lite/src/runtime/agent/npu/CMakeLists.txt b/mindspore/lite/src/runtime/agent/npu/CMakeLists.txt
index 6971dfa3f5..7c17066497 100644
--- a/mindspore/lite/src/runtime/agent/npu/CMakeLists.txt
+++ b/mindspore/lite/src/runtime/agent/npu/CMakeLists.txt
@@ -14,6 +14,7 @@ add_library(hiai_ir_build SHARED IMPORTED)
 set_target_properties(hiai_ir_build PROPERTIES IMPORTED_LOCATION
         ${DDK_LIB_PATH}/libhiai_ir_build.so)
 add_library(npu_kernel_mid OBJECT ${NPU_RUNTIME_SRC})
+add_dependencies(npu_kernel_mid fbs_src)
 target_link_libraries(
         npu_kernel_mid
         hiai
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/layer_norm_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/layer_norm_fp32.cc
index 03bca32835..07c299ba6d 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/layer_norm_fp32.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/layer_norm_fp32.cc
@@ -34,9 +34,6 @@ int LayerNormCPUKernel::Init() {
 }
 
 int LayerNormCPUKernel::ReSize() {
-  param_->begin_norm_axis_ = -1;
-  param_->begin_params_axis_ = -1;
-
   auto shape = in_tensors_.front()->shape();
   param_->begin_norm_axis_ =
     param_->begin_norm_axis_ > 0 ? param_->begin_norm_axis_ : param_->begin_norm_axis_ + shape.size();
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/layer_norm_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/layer_norm_int8.cc
index 4b986a0932..cebe3b4ca6 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/layer_norm_int8.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/layer_norm_int8.cc
@@ -82,9 +82,6 @@ int LayerNormInt8CPUKernel::Init() {
 }
 
 int LayerNormInt8CPUKernel::ReSize() {
-  param_->begin_norm_axis_ = -1;
-  param_->begin_params_axis_ = -1;
-
   auto shape = in_tensors_.front()->shape();
   param_->begin_norm_axis_ =
     param_->begin_norm_axis_ > 0 ? param_->begin_norm_axis_ : param_->begin_norm_axis_ + shape.size();
diff --git a/mindspore/lite/tools/converter/anf_transform.cc b/mindspore/lite/tools/converter/anf_transform.cc
index e15711a446..b272ff65ea 100644
--- a/mindspore/lite/tools/converter/anf_transform.cc
+++ b/mindspore/lite/tools/converter/anf_transform.cc
@@ -122,9 +122,6 @@ int AnfTransform::AddGraphPass(const std::shared_ptr &optim
   weight_format_transform_pass->SetFmkType(config->fmk);
   weight_format_transform_pass->SetQuantType(config->quantType);
   graph_pm->AddPass(weight_format_transform_pass);
-  auto infershape_pass = std::make_shared<opt::InferShapePass>();
-  infershape_pass->SetFmkType(config->fmk);
-  graph_pm->AddPass(infershape_pass);
   auto slice_prepose_pass = std::make_shared<opt::SlicePreposePass>();
   slice_prepose_pass->SetFmkType(config->fmk);
   graph_pm->AddPass(slice_prepose_pass);
@@ -155,6 +152,9 @@ int AnfTransform::AddConstFoldPass(const std::shared_ptr &o
   auto update_conv2d_param_pass = std::make_shared<opt::UpdateConv2DParamPass>();
   update_conv2d_param_pass->SetFmkType(config->fmk);
   const_fold_pm->AddPass(update_conv2d_param_pass);
+  auto infershape_pass = std::make_shared<opt::InferShapePass>();
+  infershape_pass->SetFmkType(config->fmk);
+  const_fold_pm->AddPass(infershape_pass);
   optimizer->AddPassManager(const_fold_pm);
   return RET_OK;
 }
diff --git a/mindspore/lite/tools/optimizer/fusion/layer_norm_fusion.cc b/mindspore/lite/tools/optimizer/fusion/layer_norm_fusion.cc
index 68c6e243fb..69acb8dec7 100644
--- a/mindspore/lite/tools/optimizer/fusion/layer_norm_fusion.cc
+++ b/mindspore/lite/tools/optimizer/fusion/layer_norm_fusion.cc
@@ -30,11 +30,6 @@
 namespace mindspore {
 namespace opt {
 namespace {
-constexpr size_t kAddInputsLength = 3;
-constexpr size_t kSubInputsLength = 3;
-constexpr size_t kMulInputsLength = 3;
-constexpr size_t kRsqrtInputsLength = 2;
-constexpr size_t kReduceInputsLength = 2;
 
 bool IsAddNode(const BaseRef &n) {
   if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) {
@@ -52,14 +47,6 @@ bool IsSquaredDifferenceNode(const BaseRef &n) {
   return false;
 }
 
-bool IsReduceNode(const BaseRef &n) {
-  if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) {
-    auto type = opt::GetCNodeType(n);
-    return type == schema::PrimitiveType_Reduce;
-  }
-  return false;
-}
-
 bool IsRsqrtNode(const BaseRef &n) {
   if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) {
     auto type = opt::GetCNodeType(n);
@@ -86,13 +73,11 @@ bool IsSubNode(const BaseRef &n) {
 }  // namespace
 
 const BaseRef LayerNormFusion::DefinePattern() const {
-  auto mean1 = std::make_shared<CondVar>(IsReduceNode);
-  VectorRef mean1_ref = VectorRef({mean1, input_});
+  VectorRef mean1_ref = VectorRef({mean1_, input_});
   auto squared_diffference1 = std::make_shared<CondVar>(IsSquaredDifferenceNode);
   VectorRef squared_diffference1_ref = VectorRef({squared_diffference1, input_, mean1_ref});
   auto mul1 = std::make_shared<CondVar>(IsMulNode);
-  auto mean2 = std::make_shared<CondVar>(IsReduceNode);
-  VectorRef mean2_ref = VectorRef({mean2, squared_diffference1_ref});
+  VectorRef mean2_ref = VectorRef({mean2_, squared_diffference1_ref});
   auto add1 = std::make_shared<CondVar>(IsAddNode);
   VectorRef add1_ref = VectorRef({add1, mean2_ref, epsilon_});
   auto rsqrt1 = std::make_shared<CondVar>(IsRsqrtNode);
@@ -109,221 +94,177 @@ const BaseRef LayerNormFusion::DefinePattern() const {
   return add2_ref;
 }
 
-CNodePtr LayerNormFusion::CreateLayerNormNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv,
-                                              const std::vector<int> &shape, const float epsilon) const {
-  MS_EXCEPTION_IF_NULL(func_graph);
+CNodePtr LayerNormFusion::CreateLayerNormNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv, float epsilon,
+                                              int begin_norm_axis, int begin_params_axis) const {
+  MS_ASSERT(func_graph != nullptr);
+  MS_ASSERT(equiv != nullptr);
   auto layer_norm_primitive = std::make_unique<schema::PrimitiveT>();
   std::unique_ptr<schema::LayerNormT> attr = std::make_unique<schema::LayerNormT>();
   attr->epsilon = epsilon;
+  attr->begin_norm_axis = begin_norm_axis;
+  attr->begin_params_axis = begin_params_axis;
   layer_norm_primitive->value.type = schema::PrimitiveType_LayerNorm;
   layer_norm_primitive->value.value = attr.release();
   auto layer_norm_cvalue = lite::PrimitiveC::Create(layer_norm_primitive.release());
   auto value_node = NewValueNode(std::shared_ptr<lite::PrimitiveC>(layer_norm_cvalue));
   std::vector<AnfNodePtr> new_node_inputs = {value_node};
   auto input_node = utils::cast<AnfNodePtr>((*equiv)[input_]);
-  MS_EXCEPTION_IF_NULL(input_node);
+  MS_ASSERT(input_node != nullptr);
   new_node_inputs.push_back(input_node);
   auto gamma_node = utils::cast<AnfNodePtr>((*equiv)[gamma_]);
-  MS_EXCEPTION_IF_NULL(gamma_node);
+  MS_ASSERT(gamma_node != nullptr);
   new_node_inputs.push_back(gamma_node);
   auto beta_node = utils::cast<AnfNodePtr>((*equiv)[beta_]);
-  MS_EXCEPTION_IF_NULL(beta_node);
+  MS_ASSERT(beta_node != nullptr);
   new_node_inputs.push_back(beta_node);
   auto new_node = func_graph->NewCNode(new_node_inputs);
   return new_node;
 }
 
-const AnfNodePtr LayerNormFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
-                                          const EquivPtr &equiv) const {
-  MS_ASSERT(func_graph != nullptr);
-  MS_ASSERT(node != nullptr);
-  MS_LOG(DEBUG) << "layer_norm pass";
-  if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(node) != lite::RET_OK) {
-    lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
-    return nullptr;
-  }
-
-  // add2
-  auto add2_cnode = node->cast<CNodePtr>();
-  if (CheckIfCNodeIsNull(add2_cnode) != lite::RET_OK || CheckInputSize(add2_cnode, kAddInputsLength) != lite::RET_OK) {
-    return nullptr;
-  }
-  auto add2_primitivec = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(add2_cnode->input(0));
-  MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::Add>>(add2_primitivec));
-  auto add2_op = utils::cast<std::shared_ptr<mindspore::lite::Add>>(add2_primitivec);
-  MS_ASSERT(add2_op != nullptr);
-  AnfNodePtr sub1_node = add2_cnode->input(2);
-  if (CheckIfAnfNodeIsNull(sub1_node) != lite::RET_OK) {
-    return nullptr;
-  }
-
-  // sub1
-  auto sub1_cnode = sub1_node->cast<CNodePtr>();
-  if (CheckIfCNodeIsNull(sub1_cnode) != lite::RET_OK || CheckInputSize(sub1_cnode, kSubInputsLength) != lite::RET_OK) {
-    return nullptr;
+bool LayerNormFusion::GetAxis(const CNodePtr &input_cnode, const std::vector<int> &mean_axes,
+                              const std::vector<int> &params_shape, int *begin_norm_axis,
+                              int *begin_params_axis) const {
+  MS_ASSERT(input_node != nullptr);
+  MS_ASSERT(begin_norm_axis != nullptr);
+  MS_ASSERT(begin_params_axis != nullptr);
+  auto abstract = input_cnode->abstract();
+  if (abstract == nullptr) {
+    MS_LOG(DEBUG) << "abstract of input is nullptr";
+    return false;
+  }
+  if (!utils::isa<abstract::AbstractTensorPtr>(abstract)) {
+    MS_LOG(DEBUG) << "Abstract should be abstract tensor";
+    return false;
+  }
+  auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract);
+  if (!utils::isa<abstract::ShapePtr>(abstract_tensor->BuildShape())) {
+    MS_LOG(DEBUG) << "Shape of Abstract should be ShapePtr";
+    return false;
+  }
+  auto shape = utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape();
+  if (mean_axes.back() + 1 != static_cast<int>(shape.size())) {
+    MS_LOG(DEBUG) << "mean node is not reduce to last axis";
+    return false;
+  }
+  for (size_t i = 1; i < mean_axes.size(); ++i) {
+    if (mean_axes[i] != mean_axes[i - 1] + 1) {
+      MS_LOG(DEBUG) << "mean axes is not continuous";
+      return false;
+    }
   }
-  auto sub1_primitivec = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(sub1_cnode->input(0));
-  MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::Sub>>(sub1_primitivec));
-  auto sub1_op = utils::cast<std::shared_ptr<mindspore::lite::Sub>>(sub1_primitivec);
-  MS_ASSERT(sub1_op != nullptr);
-  AnfNodePtr beta_node = sub1_cnode->input(1);
-  AnfNodePtr mul3_node = sub1_cnode->input(2);
-  if (CheckIfAnfNodeIsNull(beta_node) != lite::RET_OK || CheckIfAnfNodeIsNull(mul3_node) != lite::RET_OK) {
-    return nullptr;
+  // there is no need to check params_shape
+  *begin_norm_axis = mean_axes.front();
+  *begin_params_axis = static_cast<int>(shape.size()) - static_cast<int>(params_shape.size());
+  if (*begin_params_axis < 0) {
+    MS_LOG(DEBUG) << "LayerNorm begin_params_axis illegal, not fuse";
+    return false;
   }
+  return true;
+}
 
+bool LayerNormFusion::CheckPattern(const EquivPtr &equiv, float *epsilon, int *begin_norm_axis,
+                                   int *begin_params_axis) const {
+  MS_ASSERT(equiv != nullptr);
+  MS_ASSERT(epsilon != nullptr);
+  MS_ASSERT(begin_norm_axis != nullptr);
+  MS_ASSERT(begin_params_axis != nullptr);
   // beta
+  auto beta_node = utils::cast<AnfNodePtr>((*equiv)[beta_]);
+  MS_ASSERT(beta_node != nullptr);
   if (CheckIfNodeIsParam(beta_node) != lite::RET_OK) {
-    return nullptr;
+    return false;
   }
   auto beta_param = beta_node->cast<ParameterPtr>()->default_param();
   auto beta_tensor = std::dynamic_pointer_cast<ParamValueLite>(beta_param);
   auto beta_shape = beta_tensor->tensor_shape();
-
-  // mul3
-  auto mul3_cnode = mul3_node->cast<CNodePtr>();
-  if (CheckIfCNodeIsNull(mul3_cnode) != lite::RET_OK || CheckInputSize(mul3_cnode, kMulInputsLength) != lite::RET_OK) {
-    return nullptr;
-  }
-  auto mul3_primitivec = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(mul3_cnode->input(0));
-  MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::Mul>>(mul3_primitivec));
-  auto mul3_op = utils::cast<std::shared_ptr<mindspore::lite::Mul>>(mul3_primitivec);
-  MS_ASSERT(mul3_op != nullptr);
-  AnfNodePtr mean1_node = mul3_cnode->input(1);
-  AnfNodePtr mul2_node = mul3_cnode->input(2);
-  if (CheckIfAnfNodeIsNull(mean1_node) != lite::RET_OK || CheckIfAnfNodeIsNull(mul2_node) != lite::RET_OK) {
-    return nullptr;
-  }
-
-  // mul2
-  auto mul2_cnode = mul2_node->cast<CNodePtr>();
-  if (CheckIfCNodeIsNull(mul2_cnode) != lite::RET_OK || CheckInputSize(mul2_cnode, kMulInputsLength) != lite::RET_OK) {
-    return nullptr;
-  }
-  auto mul2_primitivec = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(mul2_cnode->input(0));
-  MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::Mul>>(mul2_primitivec));
-  auto mul2_op = utils::cast<std::shared_ptr<mindspore::lite::Mul>>(mul2_primitivec);
-  MS_ASSERT(mul2_op != nullptr);
-  AnfNodePtr rsqrt_node = mul2_cnode->input(1);
-  AnfNodePtr gamma_node = mul2_cnode->input(2);
-  if (CheckIfAnfNodeIsNull(rsqrt_node) != lite::RET_OK || CheckIfAnfNodeIsNull(gamma_node) != lite::RET_OK) {
-    return nullptr;
-  }
-
   // gamma
+  auto gamma_node = utils::cast<AnfNodePtr>((*equiv)[gamma_]);
+  MS_ASSERT(gamma_node != nullptr);
   if (CheckIfNodeIsParam(gamma_node) != lite::RET_OK) {
-    return nullptr;
+    return false;
   }
   auto gamma_param = gamma_node->cast<ParameterPtr>()->default_param();
   auto gamma_tensor = std::dynamic_pointer_cast<ParamValueLite>(gamma_param);
   auto gamma_shape = gamma_tensor->tensor_shape();
-
-  // rsqrt
-  auto rsqrt_cnode = rsqrt_node->cast<CNodePtr>();
-  if (CheckIfCNodeIsNull(rsqrt_cnode) != lite::RET_OK ||
-      CheckInputSize(rsqrt_cnode, kRsqrtInputsLength) != lite::RET_OK) {
-    return nullptr;
-  }
-  auto rsqrt_primitivec = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(rsqrt_cnode->input(0));
-  MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::Rsqrt>>(rsqrt_primitivec));
-  auto rsqrt_op = utils::cast<std::shared_ptr<mindspore::lite::Rsqrt>>(rsqrt_primitivec);
-  MS_ASSERT(rsqrt_op != nullptr);
-  AnfNodePtr add1_node = rsqrt_cnode->input(1);
-  if (CheckIfAnfNodeIsNull(add1_node) != lite::RET_OK) {
-    return nullptr;
-  }
-
-  // add1
-  auto add1_cnode = add1_node->cast<CNodePtr>();
-  if (CheckIfCNodeIsNull(add1_cnode) != lite::RET_OK || CheckInputSize(add1_cnode, kAddInputsLength) != lite::RET_OK) {
-    return nullptr;
-  }
-  auto add1_primitivec = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(add1_cnode->input(0));
-  MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::Add>>(add1_primitivec));
-  auto add1_op = utils::cast<std::shared_ptr<mindspore::lite::Add>>(add1_primitivec);
-  MS_ASSERT(add1_op != nullptr);
-  AnfNodePtr mean2_node = add1_cnode->input(1);
-  AnfNodePtr epsilon_node = add1_cnode->input(2);
-  if (CheckIfAnfNodeIsNull(mean2_node) != lite::RET_OK || CheckIfAnfNodeIsNull(epsilon_node) != lite::RET_OK) {
-    return nullptr;
-  }
-
   // epsilon
+  auto epsilon_node = utils::cast<AnfNodePtr>((*equiv)[epsilon_]);
+  MS_ASSERT(epsilon_node != nullptr);
   if (CheckIfNodeIsParam(epsilon_node) != lite::RET_OK) {
-    // delete[] add_bias_data;
-    return nullptr;
+    return false;
   }
   auto epsilon_param = epsilon_node->cast<ParameterPtr>()->default_param();
   auto epsilon_tensor = std::dynamic_pointer_cast<ParamValueLite>(epsilon_param);
   auto epsilon_data = reinterpret_cast<float *>(epsilon_tensor->tensor_addr());
   auto epsilon_shape = epsilon_tensor->tensor_shape();
-  // mean2
-  auto mean2_cnode = mean2_node->cast<CNodePtr>();
-  if (CheckIfCNodeIsNull(mean2_cnode) != lite::RET_OK ||
-      CheckInputSize(mean2_cnode, kReduceInputsLength) != lite::RET_OK) {
-    return nullptr;
+  auto mean2_value = utils::cast<AnfNodePtr>((*equiv)[mean2_]);
+  MS_ASSERT(mean2_value != nullptr);
+  auto mean2_primitivec = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(mean2_value);
+  if (!utils::isa<std::shared_ptr<mindspore::lite::Reduce>>(mean2_primitivec)) {
+    return false;
   }
-  auto mean2_primitivec = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(mean2_cnode->input(0));
-  MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::Reduce>>(mean2_primitivec));
   auto mean2_op = utils::cast<std::shared_ptr<mindspore::lite::Reduce>>(mean2_primitivec);
   MS_ASSERT(mean2_op != nullptr);
   if (mean2_op->GetMode() != schema::ReduceMode_ReduceMean) {
-    return nullptr;
+    return false;
   }
   auto mean2_axes = mean2_op->GetAxes();
-  AnfNodePtr squared_difference_node = mean2_cnode->input(1);
-  if (CheckIfAnfNodeIsNull(squared_difference_node) != lite::RET_OK) {
-    return nullptr;
-  }
-  // mean1
-  auto mean1_cnode = mean1_node->cast<CNodePtr>();
-  if (CheckIfCNodeIsNull(mean1_cnode) != lite::RET_OK ||
-      CheckInputSize(mean1_cnode, kReduceInputsLength) != lite::RET_OK) {
-    return nullptr;
+  auto mean1_value = utils::cast<AnfNodePtr>((*equiv)[mean1_]);
+  MS_ASSERT(mean1_value != nullptr);
+  auto mean1_primitivec = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(mean1_value);
+  if (!utils::isa<std::shared_ptr<mindspore::lite::Reduce>>(mean1_primitivec)) {
+    return false;
   }
-  auto mean1_primitivec = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(mean1_cnode->input(0));
-  MS_ASSERT(utils::isa<std::shared_ptr<mindspore::lite::Reduce>>(mean1_primitivec));
   auto mean1_op = utils::cast<std::shared_ptr<mindspore::lite::Reduce>>(mean1_primitivec);
   MS_ASSERT(mean1_op != nullptr);
   if (mean1_op->GetMode() != schema::ReduceMode_ReduceMean) {
-    return nullptr;
+    return false;
   }
-  AnfNodePtr input3_node = mean1_cnode->input(1);
   auto mean1_axes = mean1_op->GetAxes();
-  if (CheckIfAnfNodeIsNull(input3_node) != lite::RET_OK) {
-    return nullptr;
-  }
-
-  // verify two mean ops have same axes
-  if (mean1_axes.size() != mean2_axes.size()) {
-    return nullptr;
+  auto input_node = utils::cast<AnfNodePtr>((*equiv)[input_]);
+  MS_ASSERT(input_node != nullptr);
+  if (!utils::isa<CNodePtr>(input_node)) {
+    return false;
   }
-  for (size_t i = 0; i < mean1_axes.size(); ++i) {
-    if (mean1_axes[i] != mean2_axes[i]) {
-      return nullptr;
-    }
+  auto input_cnode = input_node->cast<CNodePtr>();
+  if (mean1_axes != mean2_axes) {
+    return false;
   }
-  // verify axes size and gamma/beta size are equal
   if (mean1_axes.size() != gamma_shape.size() || mean1_axes.size() != beta_shape.size()) {
-    return nullptr;
+    return false;
   }
-  // verify gamma and beta have same shape
-  for (size_t i = 0; i < gamma_shape.size(); ++i) {
-    if (gamma_shape[i] != beta_shape[i]) {
-      return nullptr;
-    }
+  if (gamma_shape != beta_shape) {
+    return false;
   }
-  // verify epsilon has exactly one element
-  float epsilon;
   if (epsilon_shape.empty() || (epsilon_shape.size() == 1 && epsilon_shape[0] == 1)) {
-    epsilon = epsilon_data[0];
+    *epsilon = epsilon_data[0];
   } else {
-    return nullptr;
+    return false;
   }
+  if (!GetAxis(input_cnode, mean1_axes, gamma_shape, begin_norm_axis, begin_params_axis)) {
+    return false;
+  }
+  return true;
+}
 
-  auto layer_norm_cnode = CreateLayerNormNode(func_graph, equiv, gamma_shape, epsilon);
+const AnfNodePtr LayerNormFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+                                          const EquivPtr &equiv) const {
+  MS_ASSERT(func_graph != nullptr);
+  MS_ASSERT(node != nullptr);
+  MS_ASSERT(equiv != nullptr);
+  MS_LOG(DEBUG) << "layer_norm_fusion pass";
+  if (!utils::isa<CNodePtr>(node)) {
+    return nullptr;
+  }
+  auto add2_cnode = node->cast<CNodePtr>();
+  float epsilon = 0.0f;
+  int begin_norm_axis = 0;
+  int begin_params_axis = 0;
+  if (!CheckPattern(equiv, &epsilon, &begin_norm_axis, &begin_params_axis)) {
+    return nullptr;
+  }
+  auto layer_norm_cnode = CreateLayerNormNode(func_graph, equiv, epsilon, begin_norm_axis, begin_params_axis);
   layer_norm_cnode->set_abstract(add2_cnode->abstract()->Clone());
   layer_norm_cnode->set_fullname_with_scope("layer_norm_" + add2_cnode->fullname_with_scope());
   MS_LOG(INFO) << "layernorm node:" << layer_norm_cnode->fullname_with_scope() << " fusion success";
diff --git a/mindspore/lite/tools/optimizer/fusion/layer_norm_fusion.h b/mindspore/lite/tools/optimizer/fusion/layer_norm_fusion.h
index 3522f66e5e..bf0960d79c 100644
--- a/mindspore/lite/tools/optimizer/fusion/layer_norm_fusion.h
+++ b/mindspore/lite/tools/optimizer/fusion/layer_norm_fusion.h
@@ -31,6 +31,8 @@ class LayerNormFusion : public PatternProcessPass {
   explicit LayerNormFusion(const std::string &name = "layer_norm_fusion", bool multigraph = true)
       : PatternProcessPass(name, multigraph) {
     input_ = std::make_shared<Var>();
+    mean1_ = std::make_shared<Var>();
+    mean2_ = std::make_shared<Var>();
     gamma_ = std::make_shared<Var>();
     beta_ = std::make_shared<Var>();
     epsilon_ = std::make_shared<Var>();
@@ -41,12 +43,17 @@ class LayerNormFusion : public PatternProcessPass {
   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
 
  private:
-  CNodePtr CreateLayerNormNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv, const std::vector<int> &shape,
-                               const float epsilon) const;
-  VarPtr input_;
-  VarPtr gamma_;
-  VarPtr beta_;
-  VarPtr epsilon_;
+  bool GetAxis(const CNodePtr &input_cnode, const std::vector<int> &mean_axes, const std::vector<int> &params_shape,
+               int *begin_norm_axis, int *begin_params_axis) const;
+  bool CheckPattern(const EquivPtr &equiv, float *epsilon, int *begin_norm_axis, int *begin_params_axis) const;
+  CNodePtr CreateLayerNormNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv, float epsilon,
+                               int begin_norm_axis, int begin_params_axis) const;
+  VarPtr input_ = nullptr;
+  VarPtr mean1_ = nullptr;
+  VarPtr mean2_ = nullptr;
+  VarPtr gamma_ = nullptr;
+  VarPtr beta_ = nullptr;
+  VarPtr epsilon_ = nullptr;
 };
 }  // namespace opt
 }  // namespace mindspore