|
|
|
@ -27,8 +27,8 @@ constexpr float MUL1_y = 0.5;
|
|
|
|
|
|
|
|
|
|
// gelu(x) = 1/2 * x * [1 + erf(x / sqrt(2))]
|
|
|
|
|
const BaseRef OnnxGeLUFusion::DefinePattern() const {
|
|
|
|
|
VectorRef div_ref({std::make_shared<CondVar>(IsSpecifiedNode<&kPrimDivFusion>), input_, div_y_});
|
|
|
|
|
VectorRef erf_ref({std::make_shared<CondVar>(IsSpecifiedNode<&kPrimErf>), div_ref});
|
|
|
|
|
VectorRef div_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimDivFusion>), input_, div_y_});
|
|
|
|
|
VectorRef erf_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimErf>), div_ref});
|
|
|
|
|
VectorRef add_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>), erf_ref, add_y_});
|
|
|
|
|
VectorRef mul1_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>), input_, mul1_y_});
|
|
|
|
|
VectorRef mul2_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>), mul1_ref, add_ref});
|
|
|
|
|