adjust global var

pull/14175/head
xuanyue 4 years ago
parent d346a861bc
commit 33879614c5

@ -538,6 +538,8 @@ inline const PrimitivePtr kPrimTileFusion = std::make_shared<Primitive>("TileFus
inline const PrimitivePtr kPrimReduceFusion = std::make_shared<Primitive>("ReduceFusion");
inline const PrimitivePtr kPrimLayerNormFusion = std::make_shared<Primitive>("LayerNormFusion");
inline const PrimitivePtr kPrimDType = std::make_shared<Primitive>("DType");
inline const PrimitivePtr kPrimDivFusion = std::make_shared<Primitive>("DivFusion");
inline const PrimitivePtr kPrimErf = std::make_shared<Primitive>("Erf");
class DoSignaturePrimitive : public Primitive {
public:

@ -34,8 +34,6 @@ using mindspore::lite::RET_OK;
using mindspore::lite::STATUS;
namespace mindspore {
namespace opt {
inline const PrimitivePtr kPrimDivFusion = std::make_shared<Primitive>("DivFusion");
inline const PrimitivePtr kPrimErf = std::make_shared<Primitive>("Erf");
inline const PrimitivePtr kPrimMakeTupleV2 = std::make_shared<Primitive>("make_tuple");
inline const PrimitivePtr kPrimIdentity = std::make_shared<Primitive>("Identity");
constexpr auto kWeightFormat = "weight_format";

@ -297,7 +297,8 @@ const BaseRef OnnxLayerNormFusion::DefinePattern() const {
VectorRef add1_ref =
VectorRef({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>), mean2_ref, epsilon_});
VectorRef sqrt_ref = VectorRef({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimSqrt>), add1_ref});
VectorRef div_ref = VectorRef({std::make_shared<CondVar>(IsSpecifiedNode<&kPrimDivFusion>), sub1_ref, sqrt_ref});
VectorRef div_ref =
VectorRef({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimDivFusion>), sub1_ref, sqrt_ref});
VectorRef mul_ref = VectorRef({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>), gamma_, div_ref});
VectorRef add2_ref = VectorRef({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>), mul_ref, beta_});
return add2_ref;

@ -27,8 +27,8 @@ constexpr float MUL1_y = 0.5;
// gelu(x) = 1/2 * x * [1 + erf(x / sqrt(2))]
const BaseRef OnnxGeLUFusion::DefinePattern() const {
VectorRef div_ref({std::make_shared<CondVar>(IsSpecifiedNode<&kPrimDivFusion>), input_, div_y_});
VectorRef erf_ref({std::make_shared<CondVar>(IsSpecifiedNode<&kPrimErf>), div_ref});
VectorRef div_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimDivFusion>), input_, div_y_});
VectorRef erf_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimErf>), div_ref});
VectorRef add_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>), erf_ref, add_y_});
VectorRef mul1_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>), input_, mul1_y_});
VectorRef mul2_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>), mul1_ref, add_ref});

Loading…
Cancel
Save