expand gelu and gelugrad op

pull/8084/head
zengzitao 4 years ago
parent 813e4624ab
commit 5cfa172720

@@ -15,6 +15,7 @@
 """expanders init"""
 from .gelu import expand_gelu
+from .gelu_grad import expand_gelugrad
 from .layernorm import expand_layernorm
 from .softmax import expand_softmax
 from .square import expand_square

@ -16,11 +16,16 @@
from mindspore._extends.graph_kernel.model import model_builder as builder
CSVALUE = 0.044715
CSVALUE_A = 1.5957691 # 2*np.sqrt(2/np.pi)
CSVALUE_SQRT_TWO_DIV_PI = 0.7978845608028564 # np.sqrt(2/np.pi)
ONE = 1.0
HALF = 0.5
def expand_gelu(expand_info):
"""Gelu expander"""
# cal formula are:
# gelu(x) = 0.5 * x * (1.0 + tanh(y))
# y = sqrt(2.0 / pi) * (x + 0.044715 * x * x * x)
# get op info.
input_desc = expand_info['input_desc'][0]
@@ -30,35 +35,29 @@ def expand_gelu(expand_info):
     with graph_builder.graph_scope('main') as graph_scope:
         # create tensor input.
         input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format'])
         graph_scope.set_input(input_x)
         dtype = input_x.dtype
         if dtype == 'float16':
             input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': 'float32'})
-        # cal tanh.
+        # cal y
         mul_0 = graph_builder.emit('Mul', [input_x, input_x])
         pow_0 = graph_builder.emit('Mul', [mul_0, input_x])
         const_csvalue = graph_builder.value(pow_0.dtype, CSVALUE, input_desc['format'])
         mul_1 = graph_builder.emit('Mul', [pow_0, const_csvalue])
         tanh_res = graph_builder.emit('TensorAdd', [input_x, mul_1])
-        const_csvalue_a = graph_builder.value(tanh_res.dtype, CSVALUE_A, input_desc['format'])
-        mul_0 = graph_builder.emit('Mul', [tanh_res, const_csvalue_a])
-        const_zero = graph_builder.value(mul_0.dtype, 0.0, input_desc['format'])
-        mul_0_min = graph_builder.emit('Minimum', [mul_0, const_zero])
-        right_mul = graph_builder.emit('Exp', [mul_0_min])
-        mul_0_abs = graph_builder.emit('Abs', [mul_0])
-        const_neg_one = graph_builder.value(mul_0_abs.dtype, -1.0, input_desc['format'])
-        mul_0_abs_neg = graph_builder.emit('Mul', [mul_0_abs, const_neg_one])
-        mul_0_abs_neg_exp = graph_builder.emit('Exp', [mul_0_abs_neg])
-        const_one = graph_builder.value(mul_0_abs_neg_exp.dtype, 1.0, input_desc['format'])
-        mul_0_abs_neg_exp_add = graph_builder.emit('TensorAdd', [mul_0_abs_neg_exp, const_one])
-        left_mul = graph_builder.emit('RealDiv', [input_x, mul_0_abs_neg_exp_add])
-        result = graph_builder.emit('Mul', [left_mul, right_mul])
+        const_csvalue_sqrt_two_div_pi = graph_builder.value(
+            tanh_res.dtype, CSVALUE_SQRT_TWO_DIV_PI, input_desc['format'])
+        y = graph_builder.emit('Mul', [tanh_res, const_csvalue_sqrt_two_div_pi])
+        # cal gelu(x)
+        tanh_y = graph_builder.emit('Tanh', [y])
+        const_one = graph_builder.value(tanh_y.dtype, ONE, input_desc['format'])
+        const_half = graph_builder.value(tanh_y.dtype, HALF, input_desc['format'])
+        tanh_y_add_one = graph_builder.emit('TensorAdd', [tanh_y, const_one])
+        mul_x = graph_builder.emit('Mul', [input_x, tanh_y_add_one])
+        result = graph_builder.emit('Mul', [const_half, mul_x])
         if dtype == 'float16':
             result = graph_builder.emit('Cast', [result], attrs={'dst_type': 'float16'})
         # set graph output.
         graph_scope.set_output(result)
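
Note: the removed exp-based lines and the new tanh-based lines compute the same approximation. With z = 2 * y, the identity 0.5 * (1.0 + tanh(y)) = exp(min(z, 0)) / (1 + exp(-|z|)) = sigmoid(z) shows the old code was just an overflow-safe rewrite of the formula in the new comment (CSVALUE_A = 2 * CSVALUE_SQRT_TWO_DIV_PI). A standalone NumPy sketch (illustrative only, not part of this commit) that checks the two forms against each other:

import numpy as np

CSVALUE = 0.044715
CSVALUE_A = 1.5957691  # 2*np.sqrt(2/np.pi)
CSVALUE_SQRT_TWO_DIV_PI = 0.7978845608028564  # np.sqrt(2/np.pi)

def gelu_tanh(x):
    # new form: 0.5 * x * (1.0 + tanh(y))
    y = CSVALUE_SQRT_TWO_DIV_PI * (x + CSVALUE * x * x * x)
    return 0.5 * x * (1.0 + np.tanh(y))

def gelu_exp(x):
    # old form: x * exp(min(z, 0)) / (1 + exp(-|z|)), with z = 2 * y
    z = CSVALUE_A * (x + CSVALUE * x * x * x)
    return x * np.exp(np.minimum(z, 0.0)) / (1.0 + np.exp(-np.abs(z)))

x = np.linspace(-10.0, 10.0, 1001)
assert np.allclose(gelu_tanh(x), gelu_exp(x), atol=1e-7)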

@@ -0,0 +1,92 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ===========================================================================
+"""generate json desc for gelugrad"""
+from mindspore._extends.graph_kernel.model import model_builder as builder
+
+CSVALUE = 0.044715
+CSVALUE_SQRT_TWO_DIV_PI = 0.7978845608028564  # np.sqrt(2/np.pi)
+CSVALUE_TRI = 0.134145  # CSVALUE * 3
+ONE = 1.0
+HALF = 0.5
+
+
+def expand_gelugrad(expand_info):
+    """GeluGrad expander"""
+    # calculation formula:
+    # gelu_grad(dy, x) = dy * y'
+    # y' = 0.5 * (1.0 + tanh(tanh_para)) + 0.5 * x * (1.0 - tanh(tanh_para) * tanh(tanh_para)) * mul_right
+    # tanh_para = sqrt(2.0 / pi) * (x + 0.044715 * x * x * x)
+    # mul_right = sqrt(2.0 / pi) * (1 + 3 * 0.044715 * x * x)
+    # get op info.
+    input_desc_0 = expand_info['input_desc'][0]
+    input_desc_1 = expand_info['input_desc'][1]
+    input_desc_2 = expand_info['input_desc'][2]
+    graph_builder = builder.GraphBuilder()
+
+    # generate a graph.
+    with graph_builder.graph_scope('main') as graph_scope:
+        # create tensor input.
+        input_dy = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
+        input_x = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format'])
+        input_y = graph_builder.tensor(input_desc_2['shape'], input_desc_2['data_type'], input_desc_2['format'])
+        graph_scope.set_input(input_dy, input_x, input_y)
+        dtype = input_dy.dtype
+        if dtype == 'float16':
+            input_dy = graph_builder.emit('Cast', [input_dy], attrs={'dst_type': 'float32'})
+            input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': 'float32'})
+
+        # create some const vars
+        const_csvalue = graph_builder.value(input_dy.dtype, CSVALUE, input_desc_0['format'])
+        const_csvalue_sqrt_two_div_pi = graph_builder.value(
+            input_dy.dtype, CSVALUE_SQRT_TWO_DIV_PI, input_desc_0['format'])
+        const_csvalue_tri = graph_builder.value(input_dy.dtype, CSVALUE_TRI, input_desc_0['format'])
+        const_one = graph_builder.value(input_dy.dtype, ONE, input_desc_0['format'])
+        const_half = graph_builder.value(input_dy.dtype, HALF, input_desc_0['format'])
+
+        # cal mul_right
+        mul_double = graph_builder.emit('Mul', [input_x, input_x])
+        mul_double_mul_tri = graph_builder.emit('Mul', [const_csvalue_tri, mul_double])
+        mul_add_one = graph_builder.emit('TensorAdd', [const_one, mul_double_mul_tri])
+        mul_right = graph_builder.emit('Mul', [const_csvalue_sqrt_two_div_pi, mul_add_one])
+
+        # cal tanh_para
+        mul_triple = graph_builder.emit('Mul', [input_x, mul_double])
+        mul_triple_mul_csvalue = graph_builder.emit('Mul', [const_csvalue, mul_triple])
+        mul_add_x = graph_builder.emit('TensorAdd', [input_x, mul_triple_mul_csvalue])
+        tanh_para = graph_builder.emit('Mul', [const_csvalue_sqrt_two_div_pi, mul_add_x])
+
+        # cal 0.5 * (1.0 + tanh(tanh_para))
+        tanh_res = graph_builder.emit('Tanh', [tanh_para])
+        tanh_res_add_one = graph_builder.emit('TensorAdd', [const_one, tanh_res])
+        half_mul_tanh_res_add_one = graph_builder.emit('Mul', [const_half, tanh_res_add_one])
+
+        # cal 0.5 * x * (1.0 - tanh(tanh_para) * tanh(tanh_para)) * mul_right
+        tan_res_double = graph_builder.emit('Mul', [tanh_res, tanh_res])
+        one_sub_tan_res_double = graph_builder.emit('Sub', [const_one, tan_res_double])
+        half_mul_x = graph_builder.emit('Mul', [const_half, input_x])
+        mul_tmp = graph_builder.emit('Mul', [half_mul_x, one_sub_tan_res_double])
+        mul_final = graph_builder.emit('Mul', [mul_tmp, mul_right])
+
+        # cal result
+        result_tmp = graph_builder.emit('TensorAdd', [half_mul_tanh_res_add_one, mul_final])
+        result = graph_builder.emit('Mul', [input_dy, result_tmp])
+        if dtype == 'float16':
+            result = graph_builder.emit('Cast', [result], attrs={'dst_type': 'float16'})
+
+        # set graph output.
+        graph_scope.set_output(result)
+
+    graph = graph_builder.get()[0]
+    return graph
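
Note: the y' formula in the new file follows from the product and chain rules, d/dx [0.5 * x * (1.0 + tanh(u))] = 0.5 * (1.0 + tanh(u)) + 0.5 * x * (1.0 - tanh(u)^2) * u', where u = tanh_para and u' = mul_right = sqrt(2/pi) * (1 + 3 * 0.044715 * x * x). A standalone NumPy sketch (illustrative only, not part of this commit) that checks the expanded gradient against a central-difference derivative of the tanh-based gelu:

import numpy as np

CSVALUE = 0.044715
CSVALUE_SQRT_TWO_DIV_PI = 0.7978845608028564  # np.sqrt(2/np.pi)

def gelu(x):
    u = CSVALUE_SQRT_TWO_DIV_PI * (x + CSVALUE * x ** 3)
    return 0.5 * x * (1.0 + np.tanh(u))

def gelu_grad(dy, x):
    # mirrors the expanded graph: dy * y'
    u = CSVALUE_SQRT_TWO_DIV_PI * (x + CSVALUE * x ** 3)
    mul_right = CSVALUE_SQRT_TWO_DIV_PI * (1.0 + 3.0 * CSVALUE * x ** 2)
    t = np.tanh(u)
    return dy * (0.5 * (1.0 + t) + 0.5 * x * (1.0 - t * t) * mul_right)

x = np.linspace(-5.0, 5.0, 101)
eps = 1e-6
numeric = (gelu(x + eps) - gelu(x - eps)) / (2.0 * eps)
assert np.allclose(gelu_grad(np.ones_like(x), x), numeric, atol=1e-6)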

@@ -153,6 +153,7 @@ class PrimLib:
         'make_tuple': Prim(CONTROL),
         'ControlDepend': Prim(CONTROL),
         'Assign': Prim(ELEMWISE),
+        'Tanh': Prim(ELEMWISE),
         '@ReduceInit': Prim(ELEMWISE),
     }

@@ -705,6 +705,7 @@ std::unordered_set<PrimitivePtr> GetExpandOps() {
     prim::kPrimSquare,
     prim::kPrimBiasAdd,
     prim::kPrimBiasAddGrad,
+    prim::kPrimGelu,
   };
   return expand_ops;
 }
@@ -731,7 +732,7 @@ std::vector<PrimitivePtr> GetFusibleOpList() {
     prim::kPrimRealDiv, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
     prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimCast,
     prim::kPrimAddN, prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData, prim::kPrimSelect,
-    prim::kPrimGreater, prim::kPrimAssign, prim::kPrimReduceSum};
+    prim::kPrimGreater, prim::kPrimAssign, prim::kPrimReduceSum, prim::kPrimTanh};
   return fusible_basic_ops;
 }
