parent
813e4624ab
commit
5cfa172720
@ -0,0 +1,92 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ===========================================================================
|
||||||
|
"""generate json desc for gelugrad"""
|
||||||
|
from mindspore._extends.graph_kernel.model import model_builder as builder
|
||||||
|
|
||||||
|
CSVALUE = 0.044715
|
||||||
|
CSVALUE_SQRT_TWO_DIV_PI = 0.7978845608028564 # np.sqrt(2/np.pi)
|
||||||
|
CSVALUE_TRI = 0.134141 # CSVALUE * 3
|
||||||
|
ONE = 1.0
|
||||||
|
HALF = 0.5
|
||||||
|
|
||||||
|
|
||||||
|
def expand_gelugrad(expand_info):
|
||||||
|
"""GeluGrad expander"""
|
||||||
|
# cal formula are:
|
||||||
|
# gelu_grad(dy, x) = dy * y'
|
||||||
|
# y' = 0.5 * (1.0 + tanh(tanh_para)) + 0.5 * x * (1.0 - tanh(tanh_para) * tanh(para)) * mul_right
|
||||||
|
# tanh_para = sqrt(2.0 / pi) * (x + 0.044715 * x * x * x)
|
||||||
|
# mul_right = sqrt(2.0 / pi) * (1 + 3 * 0.044715 * x * x)
|
||||||
|
|
||||||
|
# get op info.
|
||||||
|
input_desc_0 = expand_info['input_desc'][0]
|
||||||
|
input_desc_1 = expand_info['input_desc'][1]
|
||||||
|
input_desc_2 = expand_info['input_desc'][2]
|
||||||
|
graph_builder = builder.GraphBuilder()
|
||||||
|
|
||||||
|
# generate a graph.
|
||||||
|
with graph_builder.graph_scope('main') as graph_scope:
|
||||||
|
# create tensor input.
|
||||||
|
input_dy = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
|
||||||
|
input_x = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format'])
|
||||||
|
input_y = graph_builder.tensor(input_desc_2['shape'], input_desc_2['data_type'], input_desc_2['format'])
|
||||||
|
graph_scope.set_input(input_dy, input_x, input_y)
|
||||||
|
dtype = input_dy.dtype
|
||||||
|
if dtype == 'float16':
|
||||||
|
input_dy = graph_builder.emit('Cast', [input_dy], attrs={'dst_type': 'float32'})
|
||||||
|
|
||||||
|
# create some const var
|
||||||
|
const_csvalue = graph_builder.value(input_dy.dtype, CSVALUE, input_desc_0['format'])
|
||||||
|
const_csvalue_sqrt_two_div_pi = graph_builder.value(
|
||||||
|
input_dy.dtype, CSVALUE_SQRT_TWO_DIV_PI, input_desc_0['format'])
|
||||||
|
const_csvalue_tri = graph_builder.value(input_dy.dtype, CSVALUE_TRI, input_desc_0['format'])
|
||||||
|
const_one = graph_builder.value(input_dy.dtype, ONE, input_desc_0['format'])
|
||||||
|
const_half = graph_builder.value(input_dy.dtype, HALF, input_desc_0['format'])
|
||||||
|
|
||||||
|
# cal mul_right
|
||||||
|
mul_double = graph_builder.emit('Mul', [input_x, input_x])
|
||||||
|
mul_double_mul_tri = graph_builder.emit('Mul', [const_csvalue_tri, mul_double])
|
||||||
|
mul_add_one = graph_builder.emit('TensorAdd', [const_one, mul_double_mul_tri])
|
||||||
|
mul_right = graph_builder.emit('Mul', [const_csvalue_sqrt_two_div_pi, mul_add_one])
|
||||||
|
|
||||||
|
# cal tanh_para
|
||||||
|
mul_triple = graph_builder.emit('Mul', [input_x, mul_double])
|
||||||
|
mul_triple_mul_csvalue = graph_builder.emit('Mul', [const_csvalue, mul_triple])
|
||||||
|
mul_add_x = graph_builder.emit('TensorAdd', [input_x, mul_triple_mul_csvalue])
|
||||||
|
tanh_para = graph_builder.emit('Mul', [const_csvalue_sqrt_two_div_pi, mul_add_x])
|
||||||
|
|
||||||
|
# cal 0.5 * (1.0 + tanh(tahn_para))
|
||||||
|
tanh_res = graph_builder.emit('Tanh', [tanh_para])
|
||||||
|
tanh_res_add_one = graph_builder.emit('TensorAdd', [const_one, tanh_res])
|
||||||
|
half_mul_tanh_res_add_one = graph_builder.emit('Mul', [const_half, tanh_res_add_one])
|
||||||
|
|
||||||
|
# cal 0.5 * x * (1.0 - tanh(tanh_para) * tanh(tanh_para)) * mul_right
|
||||||
|
tan_res_double = graph_builder.emit('Mul', [tanh_res, tanh_res])
|
||||||
|
one_sub_tan_res_double = graph_builder.emit('Sub', [const_one, tan_res_double])
|
||||||
|
half_mul_x = graph_builder.emit('Mul', [const_half, input_x])
|
||||||
|
mul_tmp = graph_builder.emit('Mul', [half_mul_x, one_sub_tan_res_double])
|
||||||
|
mul_final = graph_builder.emit('Mul', [mul_tmp, mul_right])
|
||||||
|
|
||||||
|
# cal result
|
||||||
|
result_tmp = graph_builder.emit('TensorAdd', [half_mul_tanh_res_add_one, mul_final])
|
||||||
|
result = graph_builder.emit('Mul', [input_dy, result_tmp])
|
||||||
|
|
||||||
|
if dtype == 'float16':
|
||||||
|
result = graph_builder.emit('Cast', [result], attrs={'dst_type': 'float16'})
|
||||||
|
# set graph output.
|
||||||
|
graph_scope.set_output(result)
|
||||||
|
|
||||||
|
graph = graph_builder.get()[0]
|
||||||
|
return graph
|
Loading…
Reference in new issue