From f4289d40f36f7edd99c87b98208944dab90ac8d5 Mon Sep 17 00:00:00 2001 From: chenlei_autodiff Date: Thu, 25 Mar 2021 16:41:58 +0800 Subject: [PATCH] add graph kernel expander ops. --- .../graph_kernel/expanders/__init__.py | 4 + .../graph_kernel/expanders/batchnorm.py | 132 ++++++++++++++++++ .../graph_kernel/expanders/batchnorm_grad.py | 102 ++++++++++++++ .../_extends/graph_kernel/expanders/relu.py | 30 ++++ .../graph_kernel/expanders/relu_grad.py | 32 +++++ .../_extends/graph_kernel/model/op_infer.py | 4 +- .../graph_kernel/graph_kernel_expander.cc | 4 + tests/st/ops/graph_kernel/test_batchnorm.py | 84 +++++++++++ .../ops/graph_kernel/test_batchnorm_grad.py | 87 ++++++++++++ tests/st/ops/graph_kernel/test_relu.py | 61 ++++++++ tests/st/ops/graph_kernel/test_relu_grad.py | 62 ++++++++ 11 files changed, 601 insertions(+), 1 deletion(-) create mode 100644 mindspore/_extends/graph_kernel/expanders/batchnorm.py create mode 100644 mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py create mode 100644 mindspore/_extends/graph_kernel/expanders/relu.py create mode 100644 mindspore/_extends/graph_kernel/expanders/relu_grad.py create mode 100644 tests/st/ops/graph_kernel/test_batchnorm.py create mode 100644 tests/st/ops/graph_kernel/test_batchnorm_grad.py create mode 100644 tests/st/ops/graph_kernel/test_relu.py create mode 100644 tests/st/ops/graph_kernel/test_relu_grad.py diff --git a/mindspore/_extends/graph_kernel/expanders/__init__.py b/mindspore/_extends/graph_kernel/expanders/__init__.py index a63336ab19..084e9d676f 100644 --- a/mindspore/_extends/graph_kernel/expanders/__init__.py +++ b/mindspore/_extends/graph_kernel/expanders/__init__.py @@ -21,6 +21,8 @@ from .clip_by_norm_no_div_sum import ClipByNormNoDivSum from .dropout_grad import DropoutGrad from .fused_adam import FusedAdam from .fused_adam_weight_decay import FusedAdamWeightDecay +from .batchnorm import BatchNorm +from .batchnorm_grad import BatchNormGrad from .gelu import GeLU from .gelu_grad import GeLUGrad from .gkdropout import GkDropout @@ -31,6 +33,8 @@ from .logsoftmax_grad import LogSoftmaxGrad from .maximum_grad import MaximumGrad from .minimum_grad import MinimumGrad from .reduce_mean import ReduceMean +from .relu import ReLU +from .relu_grad import ReluGrad from .softmax import Softmax from .sigmoid import Sigmoid from .sigmoid_grad import SigmoidGrad diff --git a/mindspore/_extends/graph_kernel/expanders/batchnorm.py b/mindspore/_extends/graph_kernel/expanders/batchnorm.py new file mode 100644 index 0000000000..8fc2b65d17 --- /dev/null +++ b/mindspore/_extends/graph_kernel/expanders/batchnorm.py @@ -0,0 +1,132 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# =========================================================================== +"""generate json desc for BatchNorm""" +from mindspore._extends.graph_kernel.model.model import DataFormat as DF +from ._utils import Expander, ExpanderInfoValidator as VLD + + +@VLD.add_format(DF.NHWC, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT) +@VLD.add_format(DF.NCHW, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT) +@VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT) +@VLD.check_attrs('is_training', 'momentum', 'epsilon') +class BatchNorm(Expander): + """BatchNorm expander""" + def _expand(self, graph_builder): + # get op info + input_x = self.inputs[0] + input_scale = self.inputs[1] + input_offset = self.inputs[2] + input_mean = self.inputs[3] + input_variance = self.inputs[4] + epsilon_v = graph_builder.value(input_scale.dtype, self.attrs['epsilon'], input_scale.data_format) + + if self.attrs['is_training']: + reduce_axis = () + shape_x = input_x.shape + if input_x.data_format == "NHWC": + reduce_axis = (0, 1, 2) + num = shape_x[0] * shape_x[1] * shape_x[2] + else: + reduce_axis = (0, 2, 3) + num = shape_x[0] * shape_x[2] * shape_x[3] + num_rec = 1.0 / num + num_rec_v = graph_builder.value(input_scale.dtype, num_rec, input_scale.data_format) + + # compute mean value of input_x + mean_sum = graph_builder.emit( + 'ReduceSum', [input_x], attrs={'reduce_axis': reduce_axis, 'keep_dims': False}) + mean_muls = graph_builder.emit('Mul', [mean_sum, num_rec_v]) + + # compute variance of input_x + if not input_x.data_format == "NHWC": + mean_muls_expand = graph_builder.emit('ExpandDims', [mean_muls], attrs={'axis': 1}) + mean_muls_expand = graph_builder.emit('ExpandDims', [mean_muls_expand], attrs={'axis': 2}) + else: + mean_muls_expand = mean_muls + var_sub = graph_builder.emit('Sub', [input_x, mean_muls_expand]) + var_mul = graph_builder.emit('Mul', [var_sub, var_sub]) + var_sum = graph_builder.emit('ReduceSum', [var_mul], attrs={'reduce_axis': reduce_axis, 'keep_dims': False}) + var_mul = graph_builder.emit('Mul', [var_sum, num_rec_v]) + + # y_sqrt_rec means 1 / sqrt(variance + epsilon), which is calculated in backward pass + scalar_one = 1.0 + scalar_one_v = graph_builder.value(input_scale.dtype, scalar_one, input_scale.data_format) + y_add = graph_builder.emit('Add', [var_mul, epsilon_v]) + y_sqrt = graph_builder.emit('Sqrt', [y_add]) + y_sqrt_rec = graph_builder.emit('RealDiv', [scalar_one_v, y_sqrt]) + + # compute res_y + tmp_sub = graph_builder.emit('Sub', [input_x, mean_muls_expand]) + if not input_x.data_format == "NHWC": + y_sqrt_rec_expand = graph_builder.emit('ExpandDims', [y_sqrt_rec], attrs={'axis': 1}) + y_sqrt_rec_expand = graph_builder.emit('ExpandDims', [y_sqrt_rec_expand], attrs={'axis': 2}) + else: + y_sqrt_rec_expand = y_sqrt_rec + y_norm = graph_builder.emit('Mul', [tmp_sub, y_sqrt_rec_expand]) + if not input_x.data_format == "NHWC": + input_scale_expand = graph_builder.emit('ExpandDims', [input_scale], attrs={'axis': 1}) + input_scale_expand = graph_builder.emit('ExpandDims', [input_scale_expand], attrs={'axis': 2}) + else: + input_scale_expand = input_scale + res_y_mul = graph_builder.emit('Mul', [input_scale_expand, y_norm]) + if not input_x.data_format == "NHWC": + input_offset_expand = graph_builder.emit('ExpandDims', [input_offset], attrs={'axis': 1}) + input_offset_expand = graph_builder.emit('ExpandDims', [input_offset_expand], attrs={'axis': 2}) + else: + input_offset_expand = input_offset + res_y = graph_builder.emit('Add', [res_y_mul, 
input_offset_expand])
+
+            # compute mean_res
+            momentum_sub = scalar_one - self.attrs['momentum']
+            momentum_v_sub = graph_builder.value(input_scale.dtype, momentum_sub, input_scale.data_format)
+            new_running_mean_tmp = graph_builder.emit('Mul', [momentum_v_sub, input_mean])
+            momentum_v = graph_builder.value(input_scale.dtype, self.attrs['momentum'], input_scale.data_format)
+            current_mean_tmp = graph_builder.emit('Mul', [momentum_v, mean_muls])
+            updated_moving_mean = graph_builder.emit('Add', [new_running_mean_tmp, current_mean_tmp])
+            mean_res = graph_builder.emit(
+                'InplaceAssign', [input_mean, updated_moving_mean, updated_moving_mean], attrs={'fake_output': True})
+
+            # variance_res uses the sample variance, so it needs to be multiplied by num / (num - 1)
+            var_num = float(num) / (num - 1)
+            var_num_v = graph_builder.value(input_scale.dtype, var_num, input_scale.data_format)
+            var_mul_update = graph_builder.emit('Mul', [var_num_v, var_mul])
+            new_running_var_tmp = graph_builder.emit('Mul', [momentum_v_sub, input_variance])
+            current_var_tmp = graph_builder.emit('Mul', [momentum_v, var_mul_update])
+            updated_moving_variance = graph_builder.emit('Add', [new_running_var_tmp, current_var_tmp])
+            variance_res = graph_builder.emit(
+                'InplaceAssign', [input_variance, updated_moving_variance, updated_moving_variance],
+                attrs={'fake_output': True})
+
+            # compute reserve: just return a C-shaped tensor
+            reserve = graph_builder.emit('Add', [input_offset, scalar_one_v])
+            return res_y, mean_res, variance_res, mean_muls, y_sqrt_rec, reserve
+        # infer mode
+        if not input_x.data_format == "NHWC":
+            input_mean = graph_builder.emit('ExpandDims', [input_mean], attrs={'axis': 1})
+            input_mean = graph_builder.emit('ExpandDims', [input_mean], attrs={'axis': 2})
+            input_scale = graph_builder.emit('ExpandDims', [input_scale], attrs={'axis': 1})
+            input_scale = graph_builder.emit('ExpandDims', [input_scale], attrs={'axis': 2})
+            input_offset = graph_builder.emit('ExpandDims', [input_offset], attrs={'axis': 1})
+            input_offset = graph_builder.emit('ExpandDims', [input_offset], attrs={'axis': 2})
+        x_sub = graph_builder.emit('Sub', [input_x, input_mean])
+        x_sub_mul = graph_builder.emit('Mul', [input_scale, x_sub])
+        var_add = graph_builder.emit('Add', [epsilon_v, input_variance])
+        var_add_sqrt = graph_builder.emit('Sqrt', [var_add])
+        if not input_x.data_format == "NHWC":
+            var_add_sqrt = graph_builder.emit('ExpandDims', [var_add_sqrt], attrs={'axis': 1})
+            var_add_sqrt = graph_builder.emit('ExpandDims', [var_add_sqrt], attrs={'axis': 2})
+        x_div = graph_builder.emit('RealDiv', [x_sub_mul, var_add_sqrt])
+        res_y = graph_builder.emit('Add', [input_offset, x_div])
+        return res_y, var_add, var_add, var_add, var_add
diff --git a/mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py b/mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py
new file mode 100644
index 0000000000..ee498b37ac
--- /dev/null
+++ b/mindspore/_extends/graph_kernel/expanders/batchnorm_grad.py
@@ -0,0 +1,102 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. +# =========================================================================== +"""generate json desc for BatchNormGrad""" +from mindspore._extends.graph_kernel.model.model import DataFormat as DF +from ._utils import Expander, ExpanderInfoValidator as VLD + +@VLD.add_format(DF.NHWC, DF.NHWC, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT) +@VLD.add_format(DF.NCHW, DF.NCHW, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT) +@VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT) +@VLD.check_attrs('is_training', 'epsilon') +class BatchNormGrad(Expander): + """BatchNormGrad expander""" + def _expand(self, graph_builder): + # get op info + input_dy = self.inputs[0] + input_x = self.inputs[1] + input_scale = self.inputs[2] + input_save_mean = self.inputs[3] + input_save_inv_variance = self.inputs[4] + + reduce_axis = () + shape_x = input_x.shape + if input_x.data_format == "NHWC": + reduce_axis = (0, 1, 2) + num = shape_x[0] * shape_x[1] * shape_x[2] + else: + reduce_axis = (0, 2, 3) + num = shape_x[0] * shape_x[2] * shape_x[3] + ori_type = input_x.dtype + if ori_type == 'float16': + input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': 'float32'}) + if input_dy.dtype == 'float16': + input_dy = graph_builder.emit('Cast', [input_dy], attrs={'dst_type': 'float32'}) + num_rec = -1.0 / num + num_rec_v = graph_builder.value(input_scale.dtype, num_rec, input_scale.data_format) + dbeta = graph_builder.emit('ReduceSum', [input_dy], attrs={'reduce_axis': reduce_axis, 'keep_dims': False}) + + # in training input_save_inv_variance means 1 / sqrt(variance + epsilon), which is calculated in forward pass + if self.attrs['is_training']: + inv_variance = input_save_inv_variance + else: + epsilon_v = graph_builder.value(input_scale.dtype, self.attrs['epsilon'], input_scale.data_format) + var_add = graph_builder.emit('Add', [input_save_inv_variance, epsilon_v]) + sqrt_var_eps = graph_builder.emit('Sqrt', [var_add]) + scalar_one = 1.0 + scalar_one_v = graph_builder.value(input_scale.dtype, scalar_one, input_scale.data_format) + inv_variance = graph_builder.emit('RealDiv', [scalar_one_v, sqrt_var_eps]) + + # compute dgamma + if not input_x.data_format == "NHWC": + input_save_mean = graph_builder.emit('ExpandDims', [input_save_mean], attrs={'axis': 1}) + input_save_mean = graph_builder.emit('ExpandDims', [input_save_mean], attrs={'axis': 2}) + inv_variance = graph_builder.emit('ExpandDims', [inv_variance], attrs={'axis': 1}) + inv_variance = graph_builder.emit('ExpandDims', [inv_variance], attrs={'axis': 2}) + input_scale = graph_builder.emit('ExpandDims', [input_scale], attrs={'axis': 1}) + input_scale = graph_builder.emit('ExpandDims', [input_scale], attrs={'axis': 2}) + x_sub_mean = graph_builder.emit('Sub', [input_x, input_save_mean]) + x_div = graph_builder.emit('Mul', [x_sub_mean, inv_variance]) + dgamma_param = graph_builder.emit('Mul', [input_dy, x_div]) + dgamma = graph_builder.emit( + 'ReduceSum', [dgamma_param], attrs={'reduce_axis': reduce_axis, 'keep_dims': False}) + + # compute dx + if self.attrs['is_training']: + tmp_b = graph_builder.emit('Mul', [num_rec_v, dbeta]) + if not input_x.data_format == "NHWC": + dgamma_expand = graph_builder.emit('ExpandDims', [dgamma], attrs={'axis': 1}) + dgamma_expand = graph_builder.emit('ExpandDims', [dgamma_expand], attrs={'axis': 2}) + tmp_b = graph_builder.emit('ExpandDims', [tmp_b], attrs={'axis': 1}) + tmp_b = 
graph_builder.emit('ExpandDims', [tmp_b], attrs={'axis': 2}) + else: + dgamma_expand = dgamma + x_sub_mean_dgamma_mul = graph_builder.emit('Mul', [x_div, dgamma_expand]) + tmp_c = graph_builder.emit('Mul', [num_rec_v, x_sub_mean_dgamma_mul]) + tmp_ab_add = graph_builder.emit('Add', [input_dy, tmp_b]) + tmp_abc_add = graph_builder.emit('Add', [tmp_ab_add, tmp_c]) + gamma_mul = graph_builder.emit('Mul', [input_scale, tmp_abc_add]) + dx = graph_builder.emit('Mul', [inv_variance, gamma_mul]) + else: + y_scale = graph_builder.emit('Mul', [input_scale, input_dy]) + dx = graph_builder.emit('Mul', [inv_variance, y_scale]) + if ori_type == 'float16': + dx = graph_builder.emit('Cast', [dx], attrs={'dst_type': 'float16'}) + + # set output tensors' data_format + dx.data_format = self.outputs[0]['format'] + dgamma.data_format = self.outputs[1]['format'] + dbeta.data_format = self.outputs[2]['format'] + + return dx, dgamma, dbeta diff --git a/mindspore/_extends/graph_kernel/expanders/relu.py b/mindspore/_extends/graph_kernel/expanders/relu.py new file mode 100644 index 0000000000..18b8ed6368 --- /dev/null +++ b/mindspore/_extends/graph_kernel/expanders/relu.py @@ -0,0 +1,30 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =========================================================================== +"""generate json desc for relu""" +from ._utils import Expander + + +class ReLU(Expander): + """ReLU expander""" + + def _expand(self, graph_builder): + input_x = self.inputs[0] + + const_zero = graph_builder.value(input_x.dtype, 0) + ge_result = graph_builder.emit('Greater', [input_x, const_zero]) + ge_result = graph_builder.emit('Cast', [ge_result], attrs={'dst_type': input_x.dtype}) + result = graph_builder.emit('Mul', [ge_result, input_x]) + + return result diff --git a/mindspore/_extends/graph_kernel/expanders/relu_grad.py b/mindspore/_extends/graph_kernel/expanders/relu_grad.py new file mode 100644 index 0000000000..d2e7f7406e --- /dev/null +++ b/mindspore/_extends/graph_kernel/expanders/relu_grad.py @@ -0,0 +1,32 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# =========================================================================== +"""generate json desc for relu_grad""" +from ._utils import Expander, ExpanderInfoValidator as VLD + + +@VLD.check_all_formats_same +class ReluGrad(Expander): + """ReLU expander""" + + def _expand(self, graph_builder): + input_x = self.inputs[0] + input_y = self.inputs[1] + + const_zero = graph_builder.value(input_y.dtype, 0) + ge_result = graph_builder.emit('Greater', [input_y, const_zero]) + ge_result = graph_builder.emit('Cast', [ge_result], attrs={'dst_type': input_x.dtype}) + result = graph_builder.emit('Mul', [ge_result, input_x]) + + return result diff --git a/mindspore/_extends/graph_kernel/model/op_infer.py b/mindspore/_extends/graph_kernel/model/op_infer.py index 6738b29924..84f08b13f9 100644 --- a/mindspore/_extends/graph_kernel/model/op_infer.py +++ b/mindspore/_extends/graph_kernel/model/op_infer.py @@ -176,7 +176,9 @@ class Reshape(_Reshape): class ExpandDims(_Reshape): def _infer_shape(self): - return list(self.inputs[0].shape).insert(self.attrs["axis"], 1) + shape = list(self.inputs[0].shape) + shape.insert(self.attrs["axis"], 1) + return shape class Cast(_Elemwise): diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc index 2b3d08bacb..4148a6af93 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_expander.cc @@ -52,6 +52,8 @@ std::unordered_set GetExpandOps() { prim::kPrimGeLU, prim::kPrimFusedAdam, prim::kPrimFusedAdamWeightDecay, + prim::kPrimBatchNorm, + prim::kPrimBatchNormGrad, prim::kPrimReduceMean, prim::kPrimMaximumGrad, prim::kPrimMinimumGrad, @@ -60,6 +62,8 @@ std::unordered_set GetExpandOps() { prim::kPrimSoftmax, prim::kPrimLayerNorm, prim::kPrimLayerNormGrad, + prim::kPrimRelu, + prim::kPrimReluGrad, prim::kPrimSigmoid, prim::kPrimSigmoidGrad, prim::kPrimSigmoidCrossEntropyWithLogits, diff --git a/tests/st/ops/graph_kernel/test_batchnorm.py b/tests/st/ops/graph_kernel/test_batchnorm.py new file mode 100644 index 0000000000..791b9fd34d --- /dev/null +++ b/tests/st/ops/graph_kernel/test_batchnorm.py @@ -0,0 +1,84 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +import numpy as np +import pytest +import mindspore.context as context +import mindspore.nn as nn +from mindspore.common.tensor import Tensor +from mindspore.common.parameter import Parameter +from mindspore.ops import operations as P + + +class Net(nn.Cell): + def __init__(self, input_scale, input_bias, input_mean, input_variance, is_training): + super(Net, self).__init__() + self.fused_bn_ex = P.BatchNorm(is_training=is_training, epsilon=1e-5, momentum=0.9) + self.scale = Parameter(input_scale, name='scale') + self.bias = Parameter(input_bias, name='b') + self.mean = Parameter(input_mean, name='mean') + self.variance = Parameter(input_variance, name='variance') + def construct(self, input_x): + return self.fused_bn_ex(input_x, self.scale, self.bias, self.mean, self.variance) + + +def get_output(x, weight, bias, moving_mean, moving_var, is_training, enable_graph_kernel=False): + if enable_graph_kernel: + context.set_context(enable_graph_kernel=True) + net = Net(Tensor(weight), Tensor(bias), Tensor(moving_mean), Tensor(moving_var), is_training) + output = net(Tensor(x)) + return output, net.mean, net.variance + + +def test_bn_train(): + x = np.random.normal(0, 1, [1, 2, 4, 4]).astype(np.float32) + weight = np.random.normal(0, 1, [2,]).astype(np.float32) + bias = np.random.normal(0, 1, [2,]).astype(np.float32) + moving_mean = np.random.normal(0, 1, [2,]).astype(np.float32) + moving_var = np.random.normal(0, 1, [2,]).astype(np.float32) + + train_expect = get_output(x, weight, bias, moving_mean, moving_var, True, False) + train_output = get_output(x, weight, bias, moving_mean, moving_var, True, True) + + assert np.allclose(train_expect[0][0].asnumpy(), train_output[0][0].asnumpy(), 0.0001, 0.0001) + assert np.allclose(train_expect[0][3].asnumpy(), train_output[0][3].asnumpy(), 0.0001, 0.0001) + assert np.allclose(train_expect[0][4].asnumpy(), train_output[0][4].asnumpy(), 0.0001, 0.0001) + assert np.allclose(train_expect[1].data.asnumpy(), train_output[1].data.asnumpy(), 0.0001, 0.0001) + assert np.allclose(train_expect[2].data.asnumpy(), train_output[2].data.asnumpy(), 0.0001, 0.0001) + +def test_bn_infer(): + x = np.random.normal(5, 1, [1, 2, 4, 4]).astype(np.float32) + weight = np.random.normal(5, 1, [2,]).astype(np.float32) + bias = np.random.normal(5, 1, [2,]).astype(np.float32) + moving_mean = np.random.normal(5, 1, [2,]).astype(np.float32) + moving_var = np.random.normal(5, 1, [2,]).astype(np.float32) + + infer_expect = get_output(x, weight, bias, moving_mean, moving_var, False, False) + infer_output = get_output(x, weight, bias, moving_mean, moving_var, False, True) + + assert np.allclose(infer_expect[0][0].asnumpy(), infer_output[0][0].asnumpy(), 0.0001, 0.0001) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_bn_train_gpu(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + test_bn_train() + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_bn_infer_gpu(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + test_bn_infer() diff --git a/tests/st/ops/graph_kernel/test_batchnorm_grad.py b/tests/st/ops/graph_kernel/test_batchnorm_grad.py new file mode 100644 index 0000000000..7c71370cb7 --- /dev/null +++ b/tests/st/ops/graph_kernel/test_batchnorm_grad.py @@ -0,0 +1,87 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the 
"License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +import pytest +import mindspore.context as context +import mindspore.nn as nn +from mindspore.common.tensor import Tensor +from mindspore.ops.operations import _grad_ops as G + + +class Net(nn.Cell): + def __init__(self, is_training): + super(Net, self).__init__() + self.fused_bn_grad_ex = G.BatchNormGrad(is_training=is_training, epsilon=1e-5) + + def construct(self, input_dy, input_x, input_scale, input_save_mean, input_save_inv_variance, input_reverse): + return self.fused_bn_grad_ex( + input_dy, input_x, input_scale, input_save_mean, input_save_inv_variance, input_reverse) + + +def get_output(input_dy, input_x, input_scale, input_save_mean, input_save_inv_variance, input_reverse, + is_training, enable_graph_kernel=False): + if enable_graph_kernel: + context.set_context(enable_graph_kernel=True) + net = Net(is_training) + output = net(input_dy, input_x, input_scale, input_save_mean, input_save_inv_variance, input_reverse) + return output + +def test_bn_grad_train(): + input_dy = Tensor(np.random.normal(5, 1, [1, 2, 4, 4]).astype(np.float32)) + input_x = Tensor(np.random.normal(5, 1, [1, 2, 4, 4]).astype(np.float32)) + input_scale = Tensor(np.random.normal(5, 1, [2,]).astype(np.float32)) + input_save_mean = Tensor(np.random.normal(5, 1, [2,]).astype(np.float32)) + input_save_inv_variance = Tensor(np.random.normal(5, 1, [2,]).astype(np.float32)) + input_reverse = Tensor(np.random.normal(5, 1, [2,]).astype(np.float32)) + + expect = get_output( + input_dy, input_x, input_scale, input_save_mean, input_save_inv_variance, input_reverse, True, False) + output = get_output( + input_dy, input_x, input_scale, input_save_mean, input_save_inv_variance, input_reverse, True, True) + + assert np.allclose(expect[0].asnumpy(), output[0].asnumpy(), 0.0001, 0.0001) + assert np.allclose(expect[1].asnumpy(), output[1].asnumpy(), 0.0001, 0.0001) + assert np.allclose(expect[2].asnumpy(), output[2].asnumpy(), 0.0001, 0.0001) + +def test_bn_grad_infer(): + input_dy = Tensor(np.random.normal(5, 1, [1, 2, 4, 4]).astype(np.float32)) + input_x = Tensor(np.random.normal(5, 1, [1, 2, 4, 4]).astype(np.float32)) + input_scale = Tensor(np.random.normal(5, 1, [2,]).astype(np.float32)) + input_save_mean = Tensor(np.random.normal(5, 1, [2,]).astype(np.float32)) + input_save_inv_variance = Tensor(np.random.normal(5, 1, [2,]).astype(np.float32)) + input_reverse = Tensor(np.random.normal(5, 1, [2,]).astype(np.float32)) + + expect = get_output( + input_dy, input_x, input_scale, input_save_mean, input_save_inv_variance, input_reverse, False, False) + output = get_output( + input_dy, input_x, input_scale, input_save_mean, input_save_inv_variance, input_reverse, False, True) + + assert np.allclose(expect[0].asnumpy(), output[0].asnumpy(), 0.0001, 0.0001) + assert np.allclose(expect[1].asnumpy(), output[1].asnumpy(), 0.0001, 0.0001) + assert np.allclose(expect[2].asnumpy(), output[2].asnumpy(), 0.0001, 0.0001) + +@pytest.mark.level0 
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_bn_grad_train_gpu():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    test_bn_grad_train()
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_bn_grad_infer_gpu():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    test_bn_grad_infer()
diff --git a/tests/st/ops/graph_kernel/test_relu.py b/tests/st/ops/graph_kernel/test_relu.py
new file mode 100644
index 0000000000..c351d202c1
--- /dev/null
+++ b/tests/st/ops/graph_kernel/test_relu.py
@@ -0,0 +1,61 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+import pytest
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops import operations as P
+
+
+class Net(nn.Cell):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.relu = P.ReLU()
+
+    def construct(self, x):
+        return self.relu(x)
+
+
+def get_output(x, enable_graph_kernel=False):
+    if enable_graph_kernel:
+        context.set_context(enable_graph_kernel=True)
+    net = Net()
+    output = net(x)
+    return output
+
+
+def test_relu(shape, dtype):
+    x = Tensor(np.random.normal(0, 10, shape).astype(dtype))
+    expect = get_output(x, False)
+    output = get_output(x, True)
+
+    expect_np = expect.asnumpy().copy()
+    output_np = output.asnumpy().copy()
+
+    assert np.allclose(expect_np, output_np, 0.0001, 0.0001)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_relu_gpu():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    test_relu((4, 3), np.int32)
+    test_relu((12, 1), np.float16)
+
+def test_relu_ascend():
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+    test_relu((4, 3), np.int32)
+    test_relu((12, 1), np.float16)
diff --git a/tests/st/ops/graph_kernel/test_relu_grad.py b/tests/st/ops/graph_kernel/test_relu_grad.py
new file mode 100644
index 0000000000..86cb5f23a0
--- /dev/null
+++ b/tests/st/ops/graph_kernel/test_relu_grad.py
@@ -0,0 +1,62 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================ +import numpy as np +import pytest +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops.operations import _grad_ops as G + + +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.relu_grad = G.ReluGrad() + + def construct(self, y_backprop, x): + return self.relu_grad(y_backprop, x) + + +def get_output(y_backprop, x, enable_graph_kernel=False): + if enable_graph_kernel: + context.set_context(enable_graph_kernel=True) + net = Net() + output = net(y_backprop, x) + return output + + +def test_relu_grad(shape1, shape2, dtype): + x = Tensor(np.random.normal(0, 10, shape1).astype(dtype)) + y_backprop = Tensor(np.random.normal(0, 10, shape2).astype(dtype)) + expect = get_output(y_backprop, x, False) + output = get_output(y_backprop, x, True) + + expect_np = expect.asnumpy().copy() + output_np = output.asnumpy().copy() + + assert np.allclose(expect_np, output_np, 0.0001, 0.0001) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_relu_grad_gpu(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + test_relu_grad((4, 3), (4, 3), np.int32) + test_relu_grad((12, 1), (12, 1), np.float16) + +def test_relu_grad_ascend(): + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + test_relu_grad((4, 3), (4, 3), np.int32) + test_relu_grad((12, 1), (12, 1), np.float16)
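
For reference, outside the diff: the BatchNorm expander added above lowers the fused operator into elementwise and reduce primitives. The NumPy sketch below mirrors the arithmetic of its NCHW training branch (batch statistics over the N/H/W axes, normalization with epsilon, and the momentum-based moving-statistic update with the num / (num - 1) sample-variance correction). It is an illustrative reference only; the function name, default arguments, and shapes are assumptions, not MindSpore APIs.

import numpy as np

def batchnorm_train_reference(x, scale, offset, moving_mean, moving_var,
                              momentum=0.9, epsilon=1e-5):
    """Illustrative NumPy mirror of the training-mode expansion (assumed NCHW layout)."""
    # x: (N, C, H, W); scale/offset/moving_mean/moving_var: (C,)
    reduce_axis = (0, 2, 3)
    num = x.shape[0] * x.shape[2] * x.shape[3]
    # batch statistics, matching ReduceSum followed by Mul with 1 / num
    mean = x.sum(axis=reduce_axis) / num
    diff = x - mean.reshape(1, -1, 1, 1)
    var = (diff * diff).sum(axis=reduce_axis) / num
    # y = scale * (x - mean) / sqrt(var + epsilon) + offset
    inv_std = 1.0 / np.sqrt(var + epsilon)
    y = (scale.reshape(1, -1, 1, 1) * diff * inv_std.reshape(1, -1, 1, 1)
         + offset.reshape(1, -1, 1, 1))
    # moving statistics; num / (num - 1) is the sample-variance correction (var_num in the expander)
    new_mean = (1.0 - momentum) * moving_mean + momentum * mean
    new_var = (1.0 - momentum) * moving_var + momentum * var * num / (num - 1)
    return y, new_mean, new_var, mean, inv_std

A quick sanity check of the expander would compare y, new_mean, and new_var here against res_y and the two InplaceAssign results, which is essentially what tests/st/ops/graph_kernel/test_batchnorm.py does against the non-expanded kernel.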