parent 6e997ad3fc
commit f4289d40f3
@@ -0,0 +1,132 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""generate json desc for BatchNorm"""
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
from ._utils import Expander, ExpanderInfoValidator as VLD


@VLD.add_format(DF.NHWC, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
@VLD.add_format(DF.NCHW, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
@VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
@VLD.check_attrs('is_training', 'momentum', 'epsilon')
class BatchNorm(Expander):
    """BatchNorm expander"""
    def _expand(self, graph_builder):
        # get op info
        input_x = self.inputs[0]
        input_scale = self.inputs[1]
        input_offset = self.inputs[2]
        input_mean = self.inputs[3]
        input_variance = self.inputs[4]
        epsilon_v = graph_builder.value(input_scale.dtype, self.attrs['epsilon'], input_scale.data_format)
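        # The expansion below replaces the fused BatchNorm node with primitive ops
        # (ReduceSum/Mul/Sub/Sqrt/RealDiv/...) emitted through graph_builder, building
        # separate graphs for the training and inference modes.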

        if self.attrs['is_training']:
            reduce_axis = ()
            shape_x = input_x.shape
            if input_x.data_format == "NHWC":
                reduce_axis = (0, 1, 2)
                num = shape_x[0] * shape_x[1] * shape_x[2]
            else:
                reduce_axis = (0, 2, 3)
                num = shape_x[0] * shape_x[2] * shape_x[3]
            num_rec = 1.0 / num
            num_rec_v = graph_builder.value(input_scale.dtype, num_rec, input_scale.data_format)

            # compute mean value of input_x
            mean_sum = graph_builder.emit(
                'ReduceSum', [input_x], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
            mean_muls = graph_builder.emit('Mul', [mean_sum, num_rec_v])

            # compute variance of input_x
            if not input_x.data_format == "NHWC":
                mean_muls_expand = graph_builder.emit('ExpandDims', [mean_muls], attrs={'axis': 1})
                mean_muls_expand = graph_builder.emit('ExpandDims', [mean_muls_expand], attrs={'axis': 2})
            else:
                mean_muls_expand = mean_muls
            var_sub = graph_builder.emit('Sub', [input_x, mean_muls_expand])
            var_mul = graph_builder.emit('Mul', [var_sub, var_sub])
            var_sum = graph_builder.emit('ReduceSum', [var_mul], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
            var_mul = graph_builder.emit('Mul', [var_sum, num_rec_v])

            # y_sqrt_rec means 1 / sqrt(variance + epsilon); it is computed here and reused by the backward pass
            scalar_one = 1.0
            scalar_one_v = graph_builder.value(input_scale.dtype, scalar_one, input_scale.data_format)
            y_add = graph_builder.emit('Add', [var_mul, epsilon_v])
            y_sqrt = graph_builder.emit('Sqrt', [y_add])
            y_sqrt_rec = graph_builder.emit('RealDiv', [scalar_one_v, y_sqrt])
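
            # For reference, the graph built so far computes, per channel:
            #   mean = sum(x, reduce_axis) / num
            #   var  = sum((x - mean)^2, reduce_axis) / num
            #   y_sqrt_rec = 1 / sqrt(var + epsilon)
            # so the normalized output below is y = scale * (x - mean) * y_sqrt_rec + offset.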

            # compute res_y
            tmp_sub = graph_builder.emit('Sub', [input_x, mean_muls_expand])
            if not input_x.data_format == "NHWC":
                y_sqrt_rec_expand = graph_builder.emit('ExpandDims', [y_sqrt_rec], attrs={'axis': 1})
                y_sqrt_rec_expand = graph_builder.emit('ExpandDims', [y_sqrt_rec_expand], attrs={'axis': 2})
            else:
                y_sqrt_rec_expand = y_sqrt_rec
            y_norm = graph_builder.emit('Mul', [tmp_sub, y_sqrt_rec_expand])
            if not input_x.data_format == "NHWC":
                input_scale_expand = graph_builder.emit('ExpandDims', [input_scale], attrs={'axis': 1})
                input_scale_expand = graph_builder.emit('ExpandDims', [input_scale_expand], attrs={'axis': 2})
            else:
                input_scale_expand = input_scale
            res_y_mul = graph_builder.emit('Mul', [input_scale_expand, y_norm])
            if not input_x.data_format == "NHWC":
                input_offset_expand = graph_builder.emit('ExpandDims', [input_offset], attrs={'axis': 1})
                input_offset_expand = graph_builder.emit('ExpandDims', [input_offset_expand], attrs={'axis': 2})
            else:
                input_offset_expand = input_offset
            res_y = graph_builder.emit('Add', [res_y_mul, input_offset_expand])

            # compute mean_res
            momentum_sub = scalar_one - self.attrs['momentum']
            momentum_v_sub = graph_builder.value(input_scale.dtype, momentum_sub, input_scale.data_format)
            new_running_mean_tmp = graph_builder.emit('Mul', [momentum_v_sub, input_mean])
            momentum_v = graph_builder.value(input_scale.dtype, self.attrs['momentum'], input_scale.data_format)
            current_mean_tmp = graph_builder.emit('Mul', [momentum_v, mean_muls])
            updated_moving_mean = graph_builder.emit('Add', [new_running_mean_tmp, current_mean_tmp])
            mean_res = graph_builder.emit(
                'InplaceAssign', [input_mean, updated_moving_mean, updated_moving_mean], attrs={'fake_output': True})
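
            # For reference, this is the usual running-statistics update
            #   moving_mean = (1 - momentum) * moving_mean + momentum * batch_mean
            # written back into the mean parameter through InplaceAssign.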

            # variance_res is calculated from the biased sample variance, so it needs
            # to be multiplied by num / (num - 1) before updating the moving variance
            var_num = float(num) / (num - 1)
            var_num_v = graph_builder.value(input_scale.dtype, var_num, input_scale.data_format)
            var_mul_update = graph_builder.emit('Mul', [var_num_v, var_mul])
            new_running_var_tmp = graph_builder.emit('Mul', [momentum_v_sub, input_variance])
            current_var_tmp = graph_builder.emit('Mul', [momentum_v, var_mul_update])
            updated_moving_variance = graph_builder.emit('Add', [new_running_var_tmp, current_var_tmp])
            variance_res = graph_builder.emit(
                'InplaceAssign', [input_variance, updated_moving_variance, updated_moving_variance],
                attrs={'fake_output': True})

            # compute reserve: just return a C-shaped tensor as a placeholder
            reserve = graph_builder.emit('Add', [input_offset, scalar_one_v])
            return res_y, mean_res, variance_res, mean_muls, y_sqrt_rec, reserve
        # infer mode
        if not input_x.data_format == "NHWC":
            input_mean = graph_builder.emit('ExpandDims', [input_mean], attrs={'axis': 1})
            input_mean = graph_builder.emit('ExpandDims', [input_mean], attrs={'axis': 2})
            input_scale = graph_builder.emit('ExpandDims', [input_scale], attrs={'axis': 1})
            input_scale = graph_builder.emit('ExpandDims', [input_scale], attrs={'axis': 2})
            input_offset = graph_builder.emit('ExpandDims', [input_offset], attrs={'axis': 1})
            input_offset = graph_builder.emit('ExpandDims', [input_offset], attrs={'axis': 2})
        x_sub = graph_builder.emit('Sub', [input_x, input_mean])
        x_sub_mul = graph_builder.emit('Mul', [input_scale, x_sub])
        var_add = graph_builder.emit('Add', [epsilon_v, input_variance])
        var_add_sqrt = graph_builder.emit('Sqrt', [var_add])
        if not input_x.data_format == "NHWC":
            var_add_sqrt = graph_builder.emit('ExpandDims', [var_add_sqrt], attrs={'axis': 1})
            var_add_sqrt = graph_builder.emit('ExpandDims', [var_add_sqrt], attrs={'axis': 2})
        x_div = graph_builder.emit('RealDiv', [x_sub_mul, var_add_sqrt])
        res_y = graph_builder.emit('Add', [input_offset, x_div])
        return res_y, var_add, var_add, var_add, var_add
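
        # Illustrative sketch (assumes numpy as np; `c` is a hypothetical reshape
        # helper, not part of this file): for NCHW inference the expansion above
        # matches
        #   c = lambda t: t.reshape(1, -1, 1, 1)
        #   y = c(scale) * (x - c(moving_mean)) / np.sqrt(c(moving_variance) + eps) + c(offset)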
@@ -0,0 +1,102 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""generate json desc for BatchNormGrad"""
from mindspore._extends.graph_kernel.model.model import DataFormat as DF
from ._utils import Expander, ExpanderInfoValidator as VLD


@VLD.add_format(DF.NHWC, DF.NHWC, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
@VLD.add_format(DF.NCHW, DF.NCHW, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
@VLD.add_format(DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT, DF.DEFAULT)
@VLD.check_attrs('is_training', 'epsilon')
class BatchNormGrad(Expander):
    """BatchNormGrad expander"""
    def _expand(self, graph_builder):
        # get op info
        input_dy = self.inputs[0]
        input_x = self.inputs[1]
        input_scale = self.inputs[2]
        input_save_mean = self.inputs[3]
        input_save_inv_variance = self.inputs[4]

        reduce_axis = ()
        shape_x = input_x.shape
        if input_x.data_format == "NHWC":
            reduce_axis = (0, 1, 2)
            num = shape_x[0] * shape_x[1] * shape_x[2]
        else:
            reduce_axis = (0, 2, 3)
            num = shape_x[0] * shape_x[2] * shape_x[3]
        ori_type = input_x.dtype
        if ori_type == 'float16':
            input_x = graph_builder.emit('Cast', [input_x], attrs={'dst_type': 'float32'})
        if input_dy.dtype == 'float16':
            input_dy = graph_builder.emit('Cast', [input_dy], attrs={'dst_type': 'float32'})
        num_rec = -1.0 / num
        num_rec_v = graph_builder.value(input_scale.dtype, num_rec, input_scale.data_format)
        dbeta = graph_builder.emit('ReduceSum', [input_dy], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
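
        # For reference: dbeta = sum(dy, reduce_axis) is the gradient w.r.t. the offset;
        # num_rec is deliberately negative (-1 / num) so the mean-subtraction terms in
        # the dx formula below can be built with Add instead of Sub.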

        # in training, input_save_inv_variance holds 1 / sqrt(variance + epsilon), calculated in
        # the forward pass; in inference it holds the moving variance, so the reciprocal is recomputed
        if self.attrs['is_training']:
            inv_variance = input_save_inv_variance
        else:
            epsilon_v = graph_builder.value(input_scale.dtype, self.attrs['epsilon'], input_scale.data_format)
            var_add = graph_builder.emit('Add', [input_save_inv_variance, epsilon_v])
            sqrt_var_eps = graph_builder.emit('Sqrt', [var_add])
            scalar_one = 1.0
            scalar_one_v = graph_builder.value(input_scale.dtype, scalar_one, input_scale.data_format)
            inv_variance = graph_builder.emit('RealDiv', [scalar_one_v, sqrt_var_eps])

        # compute dgamma
        if not input_x.data_format == "NHWC":
            input_save_mean = graph_builder.emit('ExpandDims', [input_save_mean], attrs={'axis': 1})
            input_save_mean = graph_builder.emit('ExpandDims', [input_save_mean], attrs={'axis': 2})
            inv_variance = graph_builder.emit('ExpandDims', [inv_variance], attrs={'axis': 1})
            inv_variance = graph_builder.emit('ExpandDims', [inv_variance], attrs={'axis': 2})
            input_scale = graph_builder.emit('ExpandDims', [input_scale], attrs={'axis': 1})
            input_scale = graph_builder.emit('ExpandDims', [input_scale], attrs={'axis': 2})
        x_sub_mean = graph_builder.emit('Sub', [input_x, input_save_mean])
        x_div = graph_builder.emit('Mul', [x_sub_mean, inv_variance])
        dgamma_param = graph_builder.emit('Mul', [input_dy, x_div])
        dgamma = graph_builder.emit(
            'ReduceSum', [dgamma_param], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
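
        # For reference: with x_hat = (x - mean) * inv_variance, this is
        #   dgamma = sum(dy * x_hat, reduce_axis)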

        # compute dx
        if self.attrs['is_training']:
            tmp_b = graph_builder.emit('Mul', [num_rec_v, dbeta])
            if not input_x.data_format == "NHWC":
                dgamma_expand = graph_builder.emit('ExpandDims', [dgamma], attrs={'axis': 1})
                dgamma_expand = graph_builder.emit('ExpandDims', [dgamma_expand], attrs={'axis': 2})
                tmp_b = graph_builder.emit('ExpandDims', [tmp_b], attrs={'axis': 1})
                tmp_b = graph_builder.emit('ExpandDims', [tmp_b], attrs={'axis': 2})
            else:
                dgamma_expand = dgamma
            x_sub_mean_dgamma_mul = graph_builder.emit('Mul', [x_div, dgamma_expand])
            tmp_c = graph_builder.emit('Mul', [num_rec_v, x_sub_mean_dgamma_mul])
            tmp_ab_add = graph_builder.emit('Add', [input_dy, tmp_b])
            tmp_abc_add = graph_builder.emit('Add', [tmp_ab_add, tmp_c])
            gamma_mul = graph_builder.emit('Mul', [input_scale, tmp_abc_add])
            dx = graph_builder.emit('Mul', [inv_variance, gamma_mul])
        else:
            y_scale = graph_builder.emit('Mul', [input_scale, input_dy])
            dx = graph_builder.emit('Mul', [inv_variance, y_scale])
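
        # For reference: in training this realizes the standard batch-norm gradient
        #   dx = scale * inv_variance * (dy - mean(dy) - x_hat * mean(dy * x_hat))
        # (the subtractions come from the negative num_rec); in inference it reduces to
        #   dx = scale * inv_variance * dy.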
        if ori_type == 'float16':
            dx = graph_builder.emit('Cast', [dx], attrs={'dst_type': 'float16'})

        # set output tensors' data_format
        dx.data_format = self.outputs[0]['format']
        dgamma.data_format = self.outputs[1]['format']
        dbeta.data_format = self.outputs[2]['format']

        return dx, dgamma, dbeta
@@ -0,0 +1,30 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""generate json desc for relu"""
from ._utils import Expander


class ReLU(Expander):
    """ReLU expander"""

    def _expand(self, graph_builder):
        input_x = self.inputs[0]

        const_zero = graph_builder.value(input_x.dtype, 0)
        ge_result = graph_builder.emit('Greater', [input_x, const_zero])
        ge_result = graph_builder.emit('Cast', [ge_result], attrs={'dst_type': input_x.dtype})
        result = graph_builder.emit('Mul', [ge_result, input_x])

        return result
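
        # For reference: this computes relu(x) = x * cast(x > 0, x.dtype), which is
        # elementwise equal to max(x, 0) (e.g. np.maximum(x, 0) in numpy).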
@@ -0,0 +1,32 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""generate json desc for relu_grad"""
from ._utils import Expander, ExpanderInfoValidator as VLD


@VLD.check_all_formats_same
class ReluGrad(Expander):
    """ReluGrad expander"""

    def _expand(self, graph_builder):
        input_x = self.inputs[0]
        input_y = self.inputs[1]

        const_zero = graph_builder.value(input_y.dtype, 0)
        ge_result = graph_builder.emit('Greater', [input_y, const_zero])
        ge_result = graph_builder.emit('Cast', [ge_result], attrs={'dst_type': input_x.dtype})
        result = graph_builder.emit('Mul', [ge_result, input_x])

        return result
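
        # For reference: inputs[0] is the incoming gradient dy and inputs[1] is the
        # forward tensor, so this computes dx = dy * (inputs[1] > 0), the usual ReLU
        # backward rule.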
@@ -0,0 +1,84 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from mindspore.ops import operations as P


class Net(nn.Cell):
    def __init__(self, input_scale, input_bias, input_mean, input_variance, is_training):
        super(Net, self).__init__()
        self.fused_bn_ex = P.BatchNorm(is_training=is_training, epsilon=1e-5, momentum=0.9)
        self.scale = Parameter(input_scale, name='scale')
        self.bias = Parameter(input_bias, name='b')
        self.mean = Parameter(input_mean, name='mean')
        self.variance = Parameter(input_variance, name='variance')

    def construct(self, input_x):
        return self.fused_bn_ex(input_x, self.scale, self.bias, self.mean, self.variance)


def get_output(x, weight, bias, moving_mean, moving_var, is_training, enable_graph_kernel=False):
    # set the flag in both cases so a previous call cannot leave graph kernel enabled
    context.set_context(enable_graph_kernel=enable_graph_kernel)
    net = Net(Tensor(weight), Tensor(bias), Tensor(moving_mean), Tensor(moving_var), is_training)
    output = net(Tensor(x))
    return output, net.mean, net.variance
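
# Note: `output` is the tuple returned by P.BatchNorm (y, batch_mean, batch_variance
# and two reserve outputs), while net.mean / net.variance are the moving statistics
# that are expected to be updated in place during training.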


def test_bn_train():
    x = np.random.normal(0, 1, [1, 2, 4, 4]).astype(np.float32)
    weight = np.random.normal(0, 1, [2,]).astype(np.float32)
    bias = np.random.normal(0, 1, [2,]).astype(np.float32)
    moving_mean = np.random.normal(0, 1, [2,]).astype(np.float32)
    moving_var = np.random.normal(0, 1, [2,]).astype(np.float32)

    train_expect = get_output(x, weight, bias, moving_mean, moving_var, True, False)
    train_output = get_output(x, weight, bias, moving_mean, moving_var, True, True)

    assert np.allclose(train_expect[0][0].asnumpy(), train_output[0][0].asnumpy(), 0.0001, 0.0001)
    assert np.allclose(train_expect[0][3].asnumpy(), train_output[0][3].asnumpy(), 0.0001, 0.0001)
    assert np.allclose(train_expect[0][4].asnumpy(), train_output[0][4].asnumpy(), 0.0001, 0.0001)
    assert np.allclose(train_expect[1].data.asnumpy(), train_output[1].data.asnumpy(), 0.0001, 0.0001)
    assert np.allclose(train_expect[2].data.asnumpy(), train_output[2].data.asnumpy(), 0.0001, 0.0001)


def test_bn_infer():
    x = np.random.normal(5, 1, [1, 2, 4, 4]).astype(np.float32)
    weight = np.random.normal(5, 1, [2,]).astype(np.float32)
    bias = np.random.normal(5, 1, [2,]).astype(np.float32)
    moving_mean = np.random.normal(5, 1, [2,]).astype(np.float32)
    moving_var = np.random.normal(5, 1, [2,]).astype(np.float32)

    infer_expect = get_output(x, weight, bias, moving_mean, moving_var, False, False)
    infer_output = get_output(x, weight, bias, moving_mean, moving_var, False, True)

    assert np.allclose(infer_expect[0][0].asnumpy(), infer_output[0][0].asnumpy(), 0.0001, 0.0001)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_bn_train_gpu():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    test_bn_train()


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_bn_infer_gpu():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    test_bn_infer()
@@ -0,0 +1,87 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore.common.tensor import Tensor
from mindspore.ops.operations import _grad_ops as G


class Net(nn.Cell):
    def __init__(self, is_training):
        super(Net, self).__init__()
        self.fused_bn_grad_ex = G.BatchNormGrad(is_training=is_training, epsilon=1e-5)

    def construct(self, input_dy, input_x, input_scale, input_save_mean, input_save_inv_variance, input_reverse):
        return self.fused_bn_grad_ex(
            input_dy, input_x, input_scale, input_save_mean, input_save_inv_variance, input_reverse)


def get_output(input_dy, input_x, input_scale, input_save_mean, input_save_inv_variance, input_reverse,
               is_training, enable_graph_kernel=False):
    # set the flag in both cases so a previous call cannot leave graph kernel enabled
    context.set_context(enable_graph_kernel=enable_graph_kernel)
    net = Net(is_training)
    output = net(input_dy, input_x, input_scale, input_save_mean, input_save_inv_variance, input_reverse)
    return output


def test_bn_grad_train():
    input_dy = Tensor(np.random.normal(5, 1, [1, 2, 4, 4]).astype(np.float32))
    input_x = Tensor(np.random.normal(5, 1, [1, 2, 4, 4]).astype(np.float32))
    input_scale = Tensor(np.random.normal(5, 1, [2,]).astype(np.float32))
    input_save_mean = Tensor(np.random.normal(5, 1, [2,]).astype(np.float32))
    input_save_inv_variance = Tensor(np.random.normal(5, 1, [2,]).astype(np.float32))
    input_reverse = Tensor(np.random.normal(5, 1, [2,]).astype(np.float32))

    expect = get_output(
        input_dy, input_x, input_scale, input_save_mean, input_save_inv_variance, input_reverse, True, False)
    output = get_output(
        input_dy, input_x, input_scale, input_save_mean, input_save_inv_variance, input_reverse, True, True)

    assert np.allclose(expect[0].asnumpy(), output[0].asnumpy(), 0.0001, 0.0001)
    assert np.allclose(expect[1].asnumpy(), output[1].asnumpy(), 0.0001, 0.0001)
    assert np.allclose(expect[2].asnumpy(), output[2].asnumpy(), 0.0001, 0.0001)


def test_bn_grad_infer():
    input_dy = Tensor(np.random.normal(5, 1, [1, 2, 4, 4]).astype(np.float32))
    input_x = Tensor(np.random.normal(5, 1, [1, 2, 4, 4]).astype(np.float32))
    input_scale = Tensor(np.random.normal(5, 1, [2,]).astype(np.float32))
    input_save_mean = Tensor(np.random.normal(5, 1, [2,]).astype(np.float32))
    input_save_inv_variance = Tensor(np.random.normal(5, 1, [2,]).astype(np.float32))
    input_reverse = Tensor(np.random.normal(5, 1, [2,]).astype(np.float32))

    expect = get_output(
        input_dy, input_x, input_scale, input_save_mean, input_save_inv_variance, input_reverse, False, False)
    output = get_output(
        input_dy, input_x, input_scale, input_save_mean, input_save_inv_variance, input_reverse, False, True)

    assert np.allclose(expect[0].asnumpy(), output[0].asnumpy(), 0.0001, 0.0001)
    assert np.allclose(expect[1].asnumpy(), output[1].asnumpy(), 0.0001, 0.0001)
    assert np.allclose(expect[2].asnumpy(), output[2].asnumpy(), 0.0001, 0.0001)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_bn_grad_train_gpu():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    test_bn_grad_train()


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_bn_grad_infer_gpu():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    test_bn_grad_infer()
@@ -0,0 +1,61 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P


class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.relu = P.ReLU()

    def construct(self, x):
        return self.relu(x)


def get_output(x, enable_graph_kernel=False):
    # set the flag in both cases so a previous call cannot leave graph kernel enabled
    context.set_context(enable_graph_kernel=enable_graph_kernel)
    net = Net()
    output = net(x)
    return output


def test_relu(shape, dtype):
    x = Tensor(np.random.normal(0, 10, shape).astype(dtype))
    expect = get_output(x, False)
    output = get_output(x, True)

    expect_np = expect.asnumpy().copy()
    output_np = output.asnumpy().copy()

    assert np.allclose(expect_np, output_np, 0.0001, 0.0001)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_relu_gpu():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    test_relu((4, 3), np.int32)
    test_relu((12, 1), np.float16)


def test_relu_ascend():
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    test_relu((4, 3), np.int32)
    test_relu((12, 1), np.float16)
@@ -0,0 +1,62 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops.operations import _grad_ops as G


class Net(nn.Cell):
    def __init__(self):
        super(Net, self).__init__()
        self.relu_grad = G.ReluGrad()

    def construct(self, y_backprop, x):
        return self.relu_grad(y_backprop, x)


def get_output(y_backprop, x, enable_graph_kernel=False):
    # set the flag in both cases so a previous call cannot leave graph kernel enabled
    context.set_context(enable_graph_kernel=enable_graph_kernel)
    net = Net()
    output = net(y_backprop, x)
    return output


def test_relu_grad(shape1, shape2, dtype):
    x = Tensor(np.random.normal(0, 10, shape1).astype(dtype))
    y_backprop = Tensor(np.random.normal(0, 10, shape2).astype(dtype))
    expect = get_output(y_backprop, x, False)
    output = get_output(y_backprop, x, True)

    expect_np = expect.asnumpy().copy()
    output_np = output.asnumpy().copy()

    assert np.allclose(expect_np, output_np, 0.0001, 0.0001)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_relu_grad_gpu():
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    test_relu_grad((4, 3), (4, 3), np.int32)
    test_relu_grad((12, 1), (12, 1), np.float16)


def test_relu_grad_ascend():
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    test_relu_grad((4, 3), (4, 3), np.int32)
    test_relu_grad((12, 1), (12, 1), np.float16)