!8492 expand maxmium_grad minimum_grad and dropout_grad
From: @zengzitao Reviewed-by: @gaoxiong1,@ckey_dou Signed-off-by: @ckey_doupull/8492/MERGE
commit
dbe5229c56
@ -0,0 +1,44 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""generate json desc for DropoutGrad"""
|
||||
from mindspore._extends.graph_kernel.model import model_builder as builder
|
||||
|
||||
|
||||
def expand_dropoutgrad(expand_info):
|
||||
"""DropoutGrad expander"""
|
||||
# get op info.
|
||||
dy_desc = expand_info['input_desc'][0]
|
||||
mask_desc = expand_info['input_desc'][1]
|
||||
keep_prob = None
|
||||
for attr in expand_info['attr']:
|
||||
if 'keep_prob' in attr:
|
||||
keep_prob = attr['keep_prob']
|
||||
if keep_prob is None:
|
||||
raise RuntimeError("keep_prob does not exist in attrs.")
|
||||
# generate a graph.
|
||||
graph_builder = builder.GraphBuilder()
|
||||
with graph_builder.graph_scope('main') as graph_scope:
|
||||
# create tensor input.
|
||||
input_dy = graph_builder.tensor(dy_desc['shape'], dy_desc['data_type'], dy_desc['format'])
|
||||
input_mask = graph_builder.tensor(mask_desc['shape'], mask_desc['data_type'], mask_desc['format'])
|
||||
graph_scope.set_input(input_dy, input_mask)
|
||||
r_keep_prob = graph_builder.value(input_dy.dtype, 1.0 / keep_prob, "DefaultFormat")
|
||||
# create op.
|
||||
result = graph_builder.emit('Mul', [input_dy, r_keep_prob])
|
||||
result = graph_builder.emit('Mul', [result, input_mask])
|
||||
# set graph output.
|
||||
graph_scope.set_output(result)
|
||||
graph = graph_builder.get()[0]
|
||||
return graph
|
@ -0,0 +1,58 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""generate json desc for maximum_grad"""
|
||||
from mindspore._extends.graph_kernel.model import model_builder as builder
|
||||
|
||||
|
||||
def expand_maximumgrad(expand_info):
|
||||
"""MaximumGrad expander"""
|
||||
# get op info.
|
||||
input_desc_0 = expand_info['input_desc'][0]
|
||||
input_desc_1 = expand_info['input_desc'][1]
|
||||
input_desc_2 = expand_info['input_desc'][2]
|
||||
attrs = expand_info['attr']
|
||||
grad_x = None
|
||||
grad_y = None
|
||||
for item in attrs:
|
||||
if 'grad_x' in item:
|
||||
grad_x = item['grad_x']
|
||||
if 'grad_y' in item:
|
||||
grad_y = item['grad_y']
|
||||
graph_builder = builder.GraphBuilder()
|
||||
|
||||
# generate a graph.
|
||||
with graph_builder.graph_scope('main') as graph_scope:
|
||||
# create tensor input.
|
||||
input_x = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
|
||||
input_y = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format'])
|
||||
input_dout = graph_builder.tensor(input_desc_2['shape'], input_desc_2['data_type'], input_desc_2['format'])
|
||||
graph_scope.set_input(input_x, input_y, input_dout)
|
||||
x_dtype = input_x.dtype
|
||||
# cal result
|
||||
ge_result = graph_builder.emit('GreaterEqual', [input_x, input_y])
|
||||
ge_result = graph_builder.emit('Cast', [ge_result], attrs={'dst_type': x_dtype})
|
||||
dx = graph_builder.emit('Mul', [ge_result, input_dout])
|
||||
dy = graph_builder.emit('Sub', [input_dout, dx])
|
||||
|
||||
# set graph output according to grad_x and grad_y
|
||||
if grad_x and grad_y:
|
||||
graph_scope.set_output(dx, dy)
|
||||
if grad_x and not grad_y:
|
||||
graph_scope.set_output(dx)
|
||||
if grad_y and not grad_x:
|
||||
graph_scope.set_output(dy)
|
||||
|
||||
graph = graph_builder.get()[0]
|
||||
return graph
|
@ -0,0 +1,58 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ===========================================================================
|
||||
"""generate json desc for minimum_grad"""
|
||||
from mindspore._extends.graph_kernel.model import model_builder as builder
|
||||
|
||||
|
||||
def expand_minimumgrad(expand_info):
|
||||
"""MinimumGrad expander"""
|
||||
# get op info.
|
||||
input_desc_0 = expand_info['input_desc'][0]
|
||||
input_desc_1 = expand_info['input_desc'][1]
|
||||
input_desc_2 = expand_info['input_desc'][2]
|
||||
attrs = expand_info['attr']
|
||||
grad_x = None
|
||||
grad_y = None
|
||||
for item in attrs:
|
||||
if 'grad_x' in item:
|
||||
grad_x = item['grad_x']
|
||||
if 'grad_y' in item:
|
||||
grad_y = item['grad_y']
|
||||
graph_builder = builder.GraphBuilder()
|
||||
# generate a graph.
|
||||
with graph_builder.graph_scope('main') as graph_scope:
|
||||
# create tensor input.
|
||||
input_x = graph_builder.tensor(input_desc_0['shape'], input_desc_0['data_type'], input_desc_0['format'])
|
||||
input_y = graph_builder.tensor(input_desc_1['shape'], input_desc_1['data_type'], input_desc_1['format'])
|
||||
input_dout = graph_builder.tensor(input_desc_2['shape'], input_desc_2['data_type'], input_desc_2['format'])
|
||||
graph_scope.set_input(input_x, input_y, input_dout)
|
||||
x_dtype = input_x.dtype
|
||||
|
||||
# cal result
|
||||
le_result = graph_builder.emit('LessEqual', [input_x, input_y])
|
||||
le_result = graph_builder.emit('Cast', [le_result], attrs={'dst_type': x_dtype})
|
||||
dx = graph_builder.emit('Mul', [le_result, input_dout])
|
||||
dy = graph_builder.emit('Sub', [input_dout, dx])
|
||||
|
||||
# set graph output according to grad_x and grad_y
|
||||
if grad_x and grad_y:
|
||||
graph_scope.set_output(dx, dy)
|
||||
if grad_x and not grad_y:
|
||||
graph_scope.set_output(dx)
|
||||
if grad_y and not grad_x:
|
||||
graph_scope.set_output(dy)
|
||||
|
||||
graph = graph_builder.get()[0]
|
||||
return graph
|
@ -0,0 +1,48 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.context as context
|
||||
from mindspore import Tensor
|
||||
from mindspore.nn import Cell
|
||||
import mindspore.ops.operations._grad_ops as G
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
|
||||
|
||||
|
||||
class MaxmumGradNet(Cell):
|
||||
def __init__(self):
|
||||
super(MaxmumGradNet, self).__init__()
|
||||
self.maximum_grad = G.MaximumGrad()
|
||||
|
||||
def construct(self, x, y, dy):
|
||||
return self.maximum_grad(x, y, dy)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_maximum_grad():
|
||||
np.random.seed(0)
|
||||
input_x = np.random.normal(0, 1, [2, 3]).astype(np.float32)
|
||||
input_y = np.random.normal(0, 1, [2, 3]).astype(np.float32)
|
||||
input_dout = np.maximum(input_x, input_y).astype(np.float32)
|
||||
net = MaxmumGradNet()
|
||||
result = net(Tensor(input_x), Tensor(input_y), Tensor(input_dout))
|
||||
dx = input_dout * (input_x >= input_y)
|
||||
dy = input_dout - dx
|
||||
assert np.allclose(result[0].asnumpy(), dx, rtol=1.e-4, atol=1.e-8, equal_nan=True)
|
||||
assert np.allclose(result[1].asnumpy(), dy, rtol=1.e-4, atol=1.e-8, equal_nan=True)
|
@ -0,0 +1,48 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.context as context
|
||||
from mindspore import Tensor
|
||||
from mindspore.nn import Cell
|
||||
import mindspore.ops.operations._grad_ops as G
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, enable_graph_kernel=True, device_target="GPU")
|
||||
|
||||
|
||||
class MinmumGradNet(Cell):
|
||||
def __init__(self):
|
||||
super(MinmumGradNet, self).__init__()
|
||||
self.minimum_grad = G.MinimumGrad()
|
||||
|
||||
def construct(self, x, y, dy):
|
||||
return self.minimum_grad(x, y, dy)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_minimum_grad():
|
||||
np.random.seed(0)
|
||||
input_x = np.random.normal(0, 1, [2, 3]).astype(np.float32)
|
||||
input_y = np.random.normal(0, 1, [2, 3]).astype(np.float32)
|
||||
input_dout = np.minimum(input_x, input_y).astype(np.float32)
|
||||
net = MinmumGradNet()
|
||||
result = net(Tensor(input_x), Tensor(input_y), Tensor(input_dout))
|
||||
dx = input_dout * (input_x <= input_y)
|
||||
dy = input_dout - dx
|
||||
assert np.allclose(result[0].asnumpy(), dx, rtol=1.e-4, atol=1.e-8, equal_nan=True)
|
||||
assert np.allclose(result[1].asnumpy(), dy, rtol=1.e-4, atol=1.e-8, equal_nan=True)
|
Loading…
Reference in new issue