diff --git a/mindspore/_extends/graph_kernel/expanders/__init__.py b/mindspore/_extends/graph_kernel/expanders/__init__.py
index babdb06d6b..e32ce450c8 100644
--- a/mindspore/_extends/graph_kernel/expanders/__init__.py
+++ b/mindspore/_extends/graph_kernel/expanders/__init__.py
@@ -15,6 +15,7 @@
 """expanders init"""
 from .gelu import expand_gelu
+from .gelu_grad import expand_gelugrad
 from .layernorm import expand_layernorm
 from .softmax import expand_softmax
 from .square import expand_square
diff --git a/mindspore/_extends/graph_kernel/expanders/gelu.py b/mindspore/_extends/graph_kernel/expanders/gelu.py
index 9c17a4e95b..d13564f76b 100644
--- a/mindspore/_extends/graph_kernel/expanders/gelu.py
+++ b/mindspore/_extends/graph_kernel/expanders/gelu.py
@@ -16,11 +16,16 @@
 from mindspore._extends.graph_kernel.model import model_builder as builder
 
 CSVALUE = 0.044715
-CSVALUE_A = 1.5957691  # 2*np.sqrt(2/np.pi)
+CSVALUE_SQRT_TWO_DIV_PI = 0.7978845608028564  # np.sqrt(2/np.pi)
+ONE = 1.0
+HALF = 0.5
 
 
 def expand_gelu(expand_info):
     """Gelu expander"""
+    # The calculation formulas are:
+    #   gelu(x) = 0.5 * x * (1.0 + tanh(y))
+    #   y = sqrt(2.0 / pi) * (x + 0.044715 * x * x * x)
 
     # get op info.
     input_desc = expand_info['input_desc'][0]
@@ -30,35 +35,29 @@ def expand_gelu(expand_info):
     with graph_builder.graph_scope('main') as graph_scope:
         # create tensor input.
         input_x = graph_builder.tensor(input_desc['shape'], input_desc['data_type'], input_desc['format'])
+        graph_scope.set_input(input_x)
         dtype = input_x.dtype
         if dtype == 'float16':
             input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': 'float32'})
 
-        # cal tanh.
+        # cal y
         mul_0 = graph_builder.emit('Mul', [input_x, input_x])
         pow_0 = graph_builder.emit('Mul', [mul_0, input_x])
         const_csvalue = graph_builder.value(pow_0.dtype, CSVALUE, input_desc['format'])
         mul_1 = graph_builder.emit('Mul', [pow_0, const_csvalue])
         tanh_res = graph_builder.emit('TensorAdd', [input_x, mul_1])
+        const_csvalue_sqrt_two_div_pi = graph_builder.value(
+            tanh_res.dtype, CSVALUE_SQRT_TWO_DIV_PI, input_desc['format'])
+        y = graph_builder.emit('Mul', [tanh_res, const_csvalue_sqrt_two_div_pi])
 
-        const_csvalue_a = graph_builder.value(tanh_res.dtype, CSVALUE_A, input_desc['format'])
-        mul_0 = graph_builder.emit('Mul', [tanh_res, const_csvalue_a])
+        # cal gelu(x)
+        tanh_y = graph_builder.emit('Tanh', [y])
+        const_one = graph_builder.value(tanh_y.dtype, ONE, input_desc['format'])
+        const_half = graph_builder.value(tanh_y.dtype, HALF, input_desc['format'])
+        tanh_y_add_one = graph_builder.emit('TensorAdd', [tanh_y, const_one])
+        mul_x = graph_builder.emit('Mul', [input_x, tanh_y_add_one])
+        result = graph_builder.emit('Mul', [const_half, mul_x])
 
-        const_zero = graph_builder.value(mul_0.dtype, 0.0, input_desc['format'])
-        mul_0_min = graph_builder.emit('Minimum', [mul_0, const_zero])
-        right_mul = graph_builder.emit('Exp', [mul_0_min])
-
-        mul_0_abs = graph_builder.emit('Abs', [mul_0])
-        const_neg_one = graph_builder.value(mul_0_abs.dtype, -1.0, input_desc['format'])
-        mul_0_abs_neg = graph_builder.emit('Mul', [mul_0_abs, const_neg_one])
-
-        mul_0_abs_neg_exp = graph_builder.emit('Exp', [mul_0_abs_neg])
-
-        const_one = graph_builder.value(mul_0_abs_neg_exp.dtype, 1.0, input_desc['format'])
-        mul_0_abs_neg_exp_add = graph_builder.emit('TensorAdd', [mul_0_abs_neg_exp, const_one])
-        left_mul = graph_builder.emit('RealDiv', [input_x, mul_0_abs_neg_exp_add])
-
-        result = graph_builder.emit('Mul', [left_mul, right_mul])
 
         if dtype == 'float16':
             result = graph_builder.emit('Cast', [result], attrs={'dst_type': 'float16'})
         # set graph output.
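For reference, the rewritten expander emits the standard tanh approximation of GELU instead of the previous sigmoid-style form. Below is a minimal NumPy sketch (not part of the patch; the function names `gelu_expanded` and `gelu_exact` are illustrative only) of the same computation, checked against the exact erf-based definition:

```python
import math

import numpy as np

CSVALUE = 0.044715
CSVALUE_SQRT_TWO_DIV_PI = 0.7978845608028564  # np.sqrt(2 / np.pi)


def gelu_expanded(x):
    """Mirror of the emitted graph: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))."""
    y = CSVALUE_SQRT_TWO_DIV_PI * (x + CSVALUE * x * x * x)
    return 0.5 * x * (1.0 + np.tanh(y))


def gelu_exact(x):
    """Exact definition: gelu(x) = x * Phi(x), with Phi the standard normal CDF."""
    return 0.5 * x * (1.0 + np.vectorize(math.erf)(x / math.sqrt(2.0)))


x = np.linspace(-5.0, 5.0, 101)
print(np.max(np.abs(gelu_expanded(x) - gelu_exact(x))))  # small: the tanh form closely tracks exact GELU
```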
diff --git a/mindspore/_extends/graph_kernel/expanders/gelu_grad.py b/mindspore/_extends/graph_kernel/expanders/gelu_grad.py
new file mode 100644
index 0000000000..7b8cb974e5
--- /dev/null
+++ b/mindspore/_extends/graph_kernel/expanders/gelu_grad.py
@@ -0,0 +1,92 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ===========================================================================
+"""generate json desc for gelugrad"""
+from mindspore._extends.graph_kernel.model import model_builder as builder
+
+CSVALUE = 0.044715
+CSVALUE_SQRT_TWO_DIV_PI = 0.7978845608028564  # np.sqrt(2/np.pi)
+CSVALUE_TRI = 0.134145  # CSVALUE * 3
+ONE = 1.0
+HALF = 0.5
+
+
+def expand_gelugrad(expand_info):
+    """GeluGrad expander"""
+    # The calculation formulas are:
+    #   gelu_grad(dy, x) = dy * y'
+    #   y' = 0.5 * (1.0 + tanh(tanh_para)) + 0.5 * x * (1.0 - tanh(tanh_para) * tanh(tanh_para)) * mul_right
+    #   tanh_para = sqrt(2.0 / pi) * (x + 0.044715 * x * x * x)
+    #   mul_right = sqrt(2.0 / pi) * (1 + 3 * 0.044715 * x * x)
+
+    # get op info.
+    input_desc_0 = expand_info['input_desc'][0]
+    input_desc_1 = expand_info['input_desc'][1]
+    input_desc_2 = expand_info['input_desc'][2]
+    graph_builder = builder.GraphBuilder()
+
+    # generate a graph.
+    with graph_builder.graph_scope('main') as graph_scope:
+        # create tensor input.
+        input_dy = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
+        input_x = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format'])
+        input_y = graph_builder.tensor(input_desc_2['shape'], input_desc_2['data_type'], input_desc_2['format'])
+        graph_scope.set_input(input_dy, input_x, input_y)
+        dtype = input_dy.dtype
+        if dtype == 'float16':
+            input_dy = graph_builder.emit('Cast', [input_dy], attrs={'dst_type': 'float32'})
+
+        # create some const vars
+        const_csvalue = graph_builder.value(input_dy.dtype, CSVALUE, input_desc_0['format'])
+        const_csvalue_sqrt_two_div_pi = graph_builder.value(
+            input_dy.dtype, CSVALUE_SQRT_TWO_DIV_PI, input_desc_0['format'])
+        const_csvalue_tri = graph_builder.value(input_dy.dtype, CSVALUE_TRI, input_desc_0['format'])
+        const_one = graph_builder.value(input_dy.dtype, ONE, input_desc_0['format'])
+        const_half = graph_builder.value(input_dy.dtype, HALF, input_desc_0['format'])
+
+        # cal mul_right
+        mul_double = graph_builder.emit('Mul', [input_x, input_x])
+        mul_double_mul_tri = graph_builder.emit('Mul', [const_csvalue_tri, mul_double])
+        mul_add_one = graph_builder.emit('TensorAdd', [const_one, mul_double_mul_tri])
+        mul_right = graph_builder.emit('Mul', [const_csvalue_sqrt_two_div_pi, mul_add_one])
+
+        # cal tanh_para
+        mul_triple = graph_builder.emit('Mul', [input_x, mul_double])
+        mul_triple_mul_csvalue = graph_builder.emit('Mul', [const_csvalue, mul_triple])
+        mul_add_x = graph_builder.emit('TensorAdd', [input_x, mul_triple_mul_csvalue])
+        tanh_para = graph_builder.emit('Mul', [const_csvalue_sqrt_two_div_pi, mul_add_x])
+
+        # cal 0.5 * (1.0 + tanh(tanh_para))
+        tanh_res = graph_builder.emit('Tanh', [tanh_para])
+        tanh_res_add_one = graph_builder.emit('TensorAdd', [const_one, tanh_res])
+        half_mul_tanh_res_add_one = graph_builder.emit('Mul', [const_half, tanh_res_add_one])
+
+        # cal 0.5 * x * (1.0 - tanh(tanh_para) * tanh(tanh_para)) * mul_right
+        tanh_res_double = graph_builder.emit('Mul', [tanh_res, tanh_res])
+        one_sub_tanh_res_double = graph_builder.emit('Sub', [const_one, tanh_res_double])
+        half_mul_x = graph_builder.emit('Mul', [const_half, input_x])
+        mul_tmp = graph_builder.emit('Mul', [half_mul_x, one_sub_tanh_res_double])
+        mul_final = graph_builder.emit('Mul', [mul_tmp, mul_right])
+
+        # cal result
+        result_tmp = graph_builder.emit('TensorAdd', [half_mul_tanh_res_add_one, mul_final])
+        result = graph_builder.emit('Mul', [input_dy, result_tmp])
+
+        if dtype == 'float16':
+            result = graph_builder.emit('Cast', [result], attrs={'dst_type': 'float16'})
+        # set graph output.
+        graph_scope.set_output(result)
+
+    graph = graph_builder.get()[0]
+    return graph
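As a sanity check on the y' formula documented in the new file, here is a minimal NumPy sketch (illustrative only, not patch code; the function names are made up) that compares the analytic gradient against a central finite difference of the tanh-form forward pass:

```python
import numpy as np

CSVALUE = 0.044715
SQRT_TWO_DIV_PI = 0.7978845608028564  # np.sqrt(2 / np.pi)


def gelu_tanh(x):
    return 0.5 * x * (1.0 + np.tanh(SQRT_TWO_DIV_PI * (x + CSVALUE * x ** 3)))


def gelu_grad_analytic(dy, x):
    # Mirrors the emitted graph: dy * (0.5 * (1 + t) + 0.5 * x * (1 - t^2) * mul_right)
    tanh_para = SQRT_TWO_DIV_PI * (x + CSVALUE * x ** 3)
    mul_right = SQRT_TWO_DIV_PI * (1.0 + 3.0 * CSVALUE * x * x)
    t = np.tanh(tanh_para)
    return dy * (0.5 * (1.0 + t) + 0.5 * x * (1.0 - t * t) * mul_right)


x = np.linspace(-4.0, 4.0, 81)
dy = np.ones_like(x)
eps = 1e-6
numeric = (gelu_tanh(x + eps) - gelu_tanh(x - eps)) / (2.0 * eps)
print(np.max(np.abs(gelu_grad_analytic(dy, x) - numeric)))  # tiny: only finite-difference error remains
```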
diff --git a/mindspore/_extends/graph_kernel/model/model.py b/mindspore/_extends/graph_kernel/model/model.py
index 1243c9d570..e8dd4fd15f 100644
--- a/mindspore/_extends/graph_kernel/model/model.py
+++ b/mindspore/_extends/graph_kernel/model/model.py
@@ -153,6 +153,7 @@ class PrimLib:
         'make_tuple': Prim(CONTROL),
         'ControlDepend': Prim(CONTROL),
         'Assign': Prim(ELEMWISE),
+        'Tanh': Prim(ELEMWISE),
         '@ReduceInit': Prim(ELEMWISE),
     }
 
diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc
index 666e02931e..dae293f7b7 100644
--- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc
@@ -705,6 +705,7 @@ std::unordered_set<PrimitivePtr> GetExpandOps() {
     prim::kPrimSquare,
     prim::kPrimBiasAdd,
     prim::kPrimBiasAddGrad,
+    prim::kPrimGelu,
   };
   return expand_ops;
 }
@@ -731,7 +732,7 @@ std::vector<PrimitivePtr> GetFusibleOpList() {
     prim::kPrimRealDiv, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
     prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimCast, prim::kPrimAddN,
     prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData, prim::kPrimSelect,
-    prim::kPrimGreater, prim::kPrimAssign, prim::kPrimReduceSum};
+    prim::kPrimGreater, prim::kPrimAssign, prim::kPrimReduceSum, prim::kPrimTanh};
   return fusible_basic_ops;
 }
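Note that both expanders cast float16 inputs up to float32 before emitting the elementwise chain and cast back only at the end. A rough NumPy illustration of the motivation (a standalone sketch, not part of the patch, assuming NumPy's float16 semantics where each float16 op rounds its result to float16):

```python
import numpy as np

SQRT_TWO_DIV_PI = 0.7978845608028564  # np.sqrt(2 / np.pi)


def gelu_tanh(x):
    return 0.5 * x * (1.0 + np.tanh(SQRT_TWO_DIV_PI * (x + 0.044715 * x * x * x)))


x16 = np.random.default_rng(0).normal(size=10000).astype(np.float16)
ref = gelu_tanh(x16.astype(np.float64))                          # high-precision reference
all_fp16 = gelu_tanh(x16)                                        # every intermediate rounds to fp16
via_fp32 = gelu_tanh(x16.astype(np.float32)).astype(np.float16)  # compute in fp32, round once at the end

print(np.max(np.abs(all_fp16.astype(np.float64) - ref)))  # rounding error accumulates per op
print(np.max(np.abs(via_fp32.astype(np.float64) - ref)))  # only the final cast rounds
```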