!8492 expand maximum_grad, minimum_grad and dropout_grad

From: @zengzitao
Reviewed-by: @gaoxiong1,@ckey_dou
Signed-off-by: @ckey_dou
pull/8492/MERGE
mindspore-ci-bot (committed by Gitee), 4 years ago
commit dbe5229c56

@@ -25,3 +25,6 @@ from .fused_adam import expand_fusedadam
 from .fused_adam_weight_decay import expand_fusedadamweightdecay
 from .reduce_mean import expand_reducemean
 from .tanh_grad import expand_tanhgrad
+from .maximum_grad import expand_maximumgrad
+from .minimum_grad import expand_minimumgrad
+from .dropout_grad import expand_dropoutgrad

@@ -0,0 +1,44 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""generate json desc for DropoutGrad"""
from mindspore._extends.graph_kernel.model import model_builder as builder
def expand_dropoutgrad(expand_info):
"""DropoutGrad expander"""
# get op info.
dy_desc = expand_info['input_desc'][0]
mask_desc = expand_info['input_desc'][1]
keep_prob = None
for attr in expand_info['attr']:
if 'keep_prob' in attr:
keep_prob = attr['keep_prob']
if keep_prob is None:
raise RuntimeError("keep_prob does not exist in attrs.")
# generate a graph.
graph_builder = builder.GraphBuilder()
with graph_builder.graph_scope('main') as graph_scope:
# create tensor input.
input_dy = graph_builder.tensor(dy_desc['shape'], dy_desc['data_type'], dy_desc['format'])
input_mask = graph_builder.tensor(mask_desc['shape'], mask_desc['data_type'], mask_desc['format'])
graph_scope.set_input(input_dy, input_mask)
r_keep_prob = graph_builder.value(input_dy.dtype, 1.0 / keep_prob, "DefaultFormat")
# create op.
result = graph_builder.emit('Mul', [input_dy, r_keep_prob])
result = graph_builder.emit('Mul', [result, input_mask])
# set graph output.
graph_scope.set_output(result)
graph = graph_builder.get()[0]
return graph
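The expansion above is equivalent to dx = dy * (1 / keep_prob) * mask: scale the incoming gradient by the inverse keep probability, then zero out the dropped positions. A minimal NumPy sketch of that identity (the helper name is illustrative, not part of this PR):

import numpy as np

def dropout_grad_reference(dy, mask, keep_prob):
    # scale by 1/keep_prob, then zero out dropped positions via the mask
    return dy * (1.0 / keep_prob) * mask

dy = np.array([[0.5, -1.0], [2.0, 0.25]], dtype=np.float32)
mask = np.array([[1.0, 0.0], [1.0, 1.0]], dtype=np.float32)
print(dropout_grad_reference(dy, mask, keep_prob=0.8))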

@@ -0,0 +1,58 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""generate json desc for maximum_grad"""
from mindspore._extends.graph_kernel.model import model_builder as builder
def expand_maximumgrad(expand_info):
"""MaximumGrad expander"""
# get op info.
input_desc_0 = expand_info['input_desc'][0]
input_desc_1 = expand_info['input_desc'][1]
input_desc_2 = expand_info['input_desc'][2]
attrs = expand_info['attr']
grad_x = None
grad_y = None
for item in attrs:
if 'grad_x' in item:
grad_x = item['grad_x']
if 'grad_y' in item:
grad_y = item['grad_y']
graph_builder = builder.GraphBuilder()
# generate a graph.
with graph_builder.graph_scope('main') as graph_scope:
# create tensor input.
input_x = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
input_y = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format'])
input_dout = graph_builder.tensor(input_desc_2['shape'], input_desc_2['data_type'], input_desc_2['format'])
graph_scope.set_input(input_x, input_y, input_dout)
x_dtype = input_x.dtype
# cal result
ge_result = graph_builder.emit('GreaterEqual', [input_x, input_y])
ge_result = graph_builder.emit('Cast', [ge_result], attrs={'dst_type': x_dtype})
dx = graph_builder.emit('Mul', [ge_result, input_dout])
dy = graph_builder.emit('Sub', [input_dout, dx])
# set graph output according to grad_x and grad_y
if grad_x and grad_y:
graph_scope.set_output(dx, dy)
if grad_x and not grad_y:
graph_scope.set_output(dx)
if grad_y and not grad_x:
graph_scope.set_output(dy)
graph = graph_builder.get()[0]
return graph
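The GreaterEqual/Cast/Mul/Sub sequence implements the usual maximum gradient: the incoming gradient flows to x wherever x >= y and to y elsewhere, so dx + dy always equals dout. A minimal NumPy sketch of the same identity (names are illustrative):

import numpy as np

def maximum_grad_reference(x, y, dout):
    dx = dout * (x >= y).astype(x.dtype)  # GreaterEqual -> Cast -> Mul
    dy = dout - dx                        # Sub
    return dx, dy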

@@ -0,0 +1,58 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""generate json desc for minimum_grad"""
from mindspore._extends.graph_kernel.model import model_builder as builder
def expand_minimumgrad(expand_info):
"""MinimumGrad expander"""
# get op info.
input_desc_0 = expand_info['input_desc'][0]
input_desc_1 = expand_info['input_desc'][1]
input_desc_2 = expand_info['input_desc'][2]
attrs = expand_info['attr']
grad_x = None
grad_y = None
for item in attrs:
if 'grad_x' in item:
grad_x = item['grad_x']
if 'grad_y' in item:
grad_y = item['grad_y']
graph_builder = builder.GraphBuilder()
# generate a graph.
with graph_builder.graph_scope('main') as graph_scope:
# create tensor input.
input_x = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
input_y = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format'])
input_dout = graph_builder.tensor(input_desc_2['shape'], input_desc_2['data_type'], input_desc_2['format'])
graph_scope.set_input(input_x, input_y, input_dout)
x_dtype = input_x.dtype
# cal result
le_result = graph_builder.emit('LessEqual', [input_x, input_y])
le_result = graph_builder.emit('Cast', [le_result], attrs={'dst_type': x_dtype})
dx = graph_builder.emit('Mul', [le_result, input_dout])
dy = graph_builder.emit('Sub', [input_dout, dx])
# set graph output according to grad_x and grad_y
if grad_x and grad_y:
graph_scope.set_output(dx, dy)
if grad_x and not grad_y:
graph_scope.set_output(dx)
if grad_y and not grad_x:
graph_scope.set_output(dy)
graph = graph_builder.get()[0]
return graph
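MinimumGrad differs from MaximumGrad only in the comparison (LessEqual instead of GreaterEqual). One consequence of this formulation is tie handling: where x == y, the comparison is true, so dx receives the full incoming gradient and dy = dout - dx receives none. A small NumPy sketch of that behavior (illustrative only):

import numpy as np

x = np.array([1.0, 2.0], dtype=np.float32)
y = np.array([1.0, 0.5], dtype=np.float32)   # first element ties with x
dout = np.ones_like(x)
dx = dout * (x <= y).astype(x.dtype)  # LessEqual -> Cast -> Mul
dy = dout - dx                        # Sub
print(dx, dy)  # dx = [1. 0.], dy = [0. 1.]: the tie sends the whole gradient to x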

@@ -702,9 +702,9 @@ FuncGraphPtr JsonDescToAnf(const std::string &json_desc, const std::vector<AnfNo
 std::unordered_set<PrimitivePtr> GetExpandOps() {
   std::unordered_set<PrimitivePtr> expand_ops = {
-    prim::kPrimSquare, prim::kPrimBiasAdd, prim::kPrimBiasAddGrad, prim::kPrimGelu,
-    prim::kPrimGeluGrad, prim::kPrimFusedAdam, prim::kPrimFusedAdamWeightDecay, prim::kPrimTanhGrad,
-    prim::kPrimReduceMean};
+    prim::kPrimSquare, prim::kPrimBiasAdd, prim::kPrimBiasAddGrad, prim::kPrimGelu,
+    prim::kPrimGeluGrad, prim::kPrimFusedAdam, prim::kPrimFusedAdamWeightDecay, prim::kPrimTanhGrad,
+    prim::kPrimReduceMean, prim::kPrimMaximumGrad, prim::kPrimMinimumGrad};
   return expand_ops;
 }

@@ -160,6 +160,7 @@ inline const PrimitivePtr kPrimLayerNormXBackprop = std::make_shared<Primitive>(
 inline const PrimitivePtr kPrimLayerNormBetaGammaBackprop = std::make_shared<Primitive>("LayerNormBetaGammaBackprop");
 inline const PrimitivePtr kPrimDropoutGenMask = std::make_shared<Primitive>("DropoutGenMask");
 inline const PrimitivePtr kPrimDropoutDoMask = std::make_shared<Primitive>("DropoutDoMask");
+inline const PrimitivePtr kPrimDropoutGrad = std::make_shared<Primitive>("DropoutGrad");
 inline const PrimitivePtr kPrimOneHot = std::make_shared<Primitive>("OneHot");
 inline const PrimitivePtr kPrimGelu = std::make_shared<Primitive>("Gelu");
 inline const PrimitivePtr kPrimGeluGrad = std::make_shared<Primitive>("GeluGrad");

@@ -0,0 +1,48 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest

import mindspore.context as context
from mindspore import Tensor
from mindspore.nn import Cell
import mindspore.ops.operations._grad_ops as G

context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")


class MaximumGradNet(Cell):
    def __init__(self):
        super(MaximumGradNet, self).__init__()
        self.maximum_grad = G.MaximumGrad()

    def construct(self, x, y, dy):
        return self.maximum_grad(x, y, dy)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_maximum_grad():
    np.random.seed(0)
    input_x = np.random.normal(0, 1, [2, 3]).astype(np.float32)
    input_y = np.random.normal(0, 1, [2, 3]).astype(np.float32)
    input_dout = np.maximum(input_x, input_y).astype(np.float32)
    net = MaximumGradNet()
    result = net(Tensor(input_x), Tensor(input_y), Tensor(input_dout))
    dx = input_dout * (input_x >= input_y)
    dy = input_dout - dx
    assert np.allclose(result[0].asnumpy(), dx, rtol=1.e-4, atol=1.e-8, equal_nan=True)
    assert np.allclose(result[1].asnumpy(), dy, rtol=1.e-4, atol=1.e-8, equal_nan=True)

@@ -0,0 +1,48 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest

import mindspore.context as context
from mindspore import Tensor
from mindspore.nn import Cell
import mindspore.ops.operations._grad_ops as G

context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")


class MinimumGradNet(Cell):
    def __init__(self):
        super(MinimumGradNet, self).__init__()
        self.minimum_grad = G.MinimumGrad()

    def construct(self, x, y, dy):
        return self.minimum_grad(x, y, dy)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_minimum_grad():
    np.random.seed(0)
    input_x = np.random.normal(0, 1, [2, 3]).astype(np.float32)
    input_y = np.random.normal(0, 1, [2, 3]).astype(np.float32)
    input_dout = np.minimum(input_x, input_y).astype(np.float32)
    net = MinimumGradNet()
    result = net(Tensor(input_x), Tensor(input_y), Tensor(input_dout))
    dx = input_dout * (input_x <= input_y)
    dy = input_dout - dx
    assert np.allclose(result[0].asnumpy(), dx, rtol=1.e-4, atol=1.e-8, equal_nan=True)
    assert np.allclose(result[1].asnumpy(), dy, rtol=1.e-4, atol=1.e-8, equal_nan=True)
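The diff adds tests for MaximumGrad and MinimumGrad but none for the new DropoutGrad expander. A test in the same style might look like the sketch below, reusing the imports above and assuming G.DropoutGrad takes keep_prob at construction and (dy, mask) as call inputs (an assumption, not confirmed by this diff):

class DropoutGradNet(Cell):
    def __init__(self, keep_prob):
        super(DropoutGradNet, self).__init__()
        # assumption: G.DropoutGrad is constructed with keep_prob
        self.dropout_grad = G.DropoutGrad(keep_prob)

    def construct(self, dy, mask):
        return self.dropout_grad(dy, mask)


def test_dropout_grad():
    np.random.seed(0)
    keep_prob = 0.8
    input_dy = np.random.normal(0, 1, [2, 3]).astype(np.float32)
    input_mask = (np.random.uniform(size=[2, 3]) < keep_prob).astype(np.float32)
    net = DropoutGradNet(keep_prob)
    result = net(Tensor(input_dy), Tensor(input_mask))
    expect = input_dy * (1.0 / keep_prob) * input_mask
    assert np.allclose(result.asnumpy(), expect, rtol=1.e-4, atol=1.e-8)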