parent
813e4624ab
commit
5cfa172720
@ -0,0 +1,92 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""generate json desc for gelugrad"""
|
||||
from mindspore._extends.graph_kernel.model import model_builder as builder
|
||||
|
||||
CSVALUE = 0.044715
|
||||
CSVALUE_SQRT_TWO_DIV_PI = 0.7978845608028564 # np.sqrt(2/np.pi)
|
||||
CSVALUE_TRI = 0.134141 # CSVALUE * 3
|
||||
ONE = 1.0
|
||||
HALF = 0.5
|
||||
|
||||
|
||||
def expand_gelugrad(expand_info):
|
||||
"""GeluGrad expander"""
|
||||
# cal formula are:
|
||||
# gelu_grad(dy, x) = dy * y'
|
||||
# y' = 0.5 * (1.0 + tanh(tanh_para)) + 0.5 * x * (1.0 - tanh(tanh_para) * tanh(para)) * mul_right
|
||||
# tanh_para = sqrt(2.0 / pi) * (x + 0.044715 * x * x * x)
|
||||
# mul_right = sqrt(2.0 / pi) * (1 + 3 * 0.044715 * x * x)
|
||||
|
||||
# get op info.
|
||||
input_desc_0 = expand_info['input_desc'][0]
|
||||
input_desc_1 = expand_info['input_desc'][1]
|
||||
input_desc_2 = expand_info['input_desc'][2]
|
||||
graph_builder = builder.GraphBuilder()
|
||||
|
||||
# generate a graph.
|
||||
with graph_builder.graph_scope('main') as graph_scope:
|
||||
# create tensor input.
|
||||
input_dy = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
|
||||
input_x = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format'])
|
||||
input_y = graph_builder.tensor(input_desc_2['shape'], input_desc_2['data_type'], input_desc_2['format'])
|
||||
graph_scope.set_input(input_dy, input_x, input_y)
|
||||
dtype = input_dy.dtype
|
||||
if dtype == 'float16':
|
||||
input_dy = graph_builder.emit('Cast', [input_dy], attrs={'dst_type': 'float32'})
|
||||
|
||||
# create some const var
|
||||
const_csvalue = graph_builder.value(input_dy.dtype, CSVALUE, input_desc_0['format'])
|
||||
const_csvalue_sqrt_two_div_pi = graph_builder.value(
|
||||
input_dy.dtype, CSVALUE_SQRT_TWO_DIV_PI, input_desc_0['format'])
|
||||
const_csvalue_tri = graph_builder.value(input_dy.dtype, CSVALUE_TRI, input_desc_0['format'])
|
||||
const_one = graph_builder.value(input_dy.dtype, ONE, input_desc_0['format'])
|
||||
const_half = graph_builder.value(input_dy.dtype, HALF, input_desc_0['format'])
|
||||
|
||||
# cal mul_right
|
||||
mul_double = graph_builder.emit('Mul', [input_x, input_x])
|
||||
mul_double_mul_tri = graph_builder.emit('Mul', [const_csvalue_tri, mul_double])
|
||||
mul_add_one = graph_builder.emit('TensorAdd', [const_one, mul_double_mul_tri])
|
||||
mul_right = graph_builder.emit('Mul', [const_csvalue_sqrt_two_div_pi, mul_add_one])
|
||||
|
||||
# cal tanh_para
|
||||
mul_triple = graph_builder.emit('Mul', [input_x, mul_double])
|
||||
mul_triple_mul_csvalue = graph_builder.emit('Mul', [const_csvalue, mul_triple])
|
||||
mul_add_x = graph_builder.emit('TensorAdd', [input_x, mul_triple_mul_csvalue])
|
||||
tanh_para = graph_builder.emit('Mul', [const_csvalue_sqrt_two_div_pi, mul_add_x])
|
||||
|
||||
# cal 0.5 * (1.0 + tanh(tahn_para))
|
||||
tanh_res = graph_builder.emit('Tanh', [tanh_para])
|
||||
tanh_res_add_one = graph_builder.emit('TensorAdd', [const_one, tanh_res])
|
||||
half_mul_tanh_res_add_one = graph_builder.emit('Mul', [const_half, tanh_res_add_one])
|
||||
|
||||
# cal 0.5 * x * (1.0 - tanh(tanh_para) * tanh(tanh_para)) * mul_right
|
||||
tan_res_double = graph_builder.emit('Mul', [tanh_res, tanh_res])
|
||||
one_sub_tan_res_double = graph_builder.emit('Sub', [const_one, tan_res_double])
|
||||
half_mul_x = graph_builder.emit('Mul', [const_half, input_x])
|
||||
mul_tmp = graph_builder.emit('Mul', [half_mul_x, one_sub_tan_res_double])
|
||||
mul_final = graph_builder.emit('Mul', [mul_tmp, mul_right])
|
||||
|
||||
# cal result
|
||||
result_tmp = graph_builder.emit('TensorAdd', [half_mul_tanh_res_add_one, mul_final])
|
||||
result = graph_builder.emit('Mul', [input_dy, result_tmp])
|
||||
|
||||
if dtype == 'float16':
|
||||
result = graph_builder.emit('Cast', [result], attrs={'dst_type': 'float16'})
|
||||
# set graph output.
|
||||
graph_scope.set_output(result)
|
||||
|
||||
graph = graph_builder.get()[0]
|
||||
return graph
|
Loading…
Reference in new issue