|
|
|
@ -31,6 +31,8 @@ class LayerNormFusion : public PatternProcessPass {
|
|
|
|
|
explicit LayerNormFusion(const std::string &name = "layer_norm_fusion", bool multigraph = true)
|
|
|
|
|
: PatternProcessPass(name, multigraph) {
|
|
|
|
|
input_ = std::make_shared<Var>();
|
|
|
|
|
mean1_ = std::make_shared<Var>();
|
|
|
|
|
mean2_ = std::make_shared<Var>();
|
|
|
|
|
gamma_ = std::make_shared<Var>();
|
|
|
|
|
beta_ = std::make_shared<Var>();
|
|
|
|
|
epsilon_ = std::make_shared<Var>();
|
|
|
|
@ -41,12 +43,17 @@ class LayerNormFusion : public PatternProcessPass {
|
|
|
|
|
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
CNodePtr CreateLayerNormNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv, const std::vector<int> &shape,
|
|
|
|
|
const float epsilon) const;
|
|
|
|
|
VarPtr input_;
|
|
|
|
|
VarPtr gamma_;
|
|
|
|
|
VarPtr beta_;
|
|
|
|
|
VarPtr epsilon_;
|
|
|
|
|
bool GetAxis(const CNodePtr &input_cnode, const std::vector<int> &mean_axes, const std::vector<int> ¶ms_shape,
|
|
|
|
|
int *begin_norm_axis, int *begin_params_axis) const;
|
|
|
|
|
bool CheckPattern(const EquivPtr &equiv, float *epsilon, int *begin_norm_axis, int *begin_params_axis) const;
|
|
|
|
|
CNodePtr CreateLayerNormNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv, float epsilon,
|
|
|
|
|
int begin_norm_axis, int begin_params_axis) const;
|
|
|
|
|
VarPtr input_ = nullptr;
|
|
|
|
|
VarPtr mean1_ = nullptr;
|
|
|
|
|
VarPtr mean2_ = nullptr;
|
|
|
|
|
VarPtr gamma_ = nullptr;
|
|
|
|
|
VarPtr beta_ = nullptr;
|
|
|
|
|
VarPtr epsilon_ = nullptr;
|
|
|
|
|
};
|
|
|
|
|
} // namespace opt
|
|
|
|
|
} // namespace mindspore
|
|
|
|
|