diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batch_norm_ex_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batch_norm_ex_gpu_kernel.h
index ddba8b1c76..82cfe6ba93 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batch_norm_ex_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/fused_batch_norm_ex_gpu_kernel.h
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -29,25 +29,7 @@ namespace kernel {
 template <typename T>
 class FusedBatchNormExGpuKernel : public GpuKernel {
  public:
-  FusedBatchNormExGpuKernel()
-      : input_x_size_(0),
-        input_z_size_(0),
-        para_size_(0),
-        output_size_(0),
-        workspace_size_(0),
-        reserve_size_(0),
-        mode_(CUDNN_BATCHNORM_SPATIAL),
-        bn_ops_(CUDNN_BATCHNORM_OPS_BN),
-        epsilon_(10e-5),
-        exp_avg_factor_(0.1),
-        is_null_input_(false),
-        x_desc_(nullptr),
-        y_desc_(nullptr),
-        z_desc_(nullptr),
-        scale_bias_mean_var_desc_(nullptr),
-        activation_desc_(nullptr),
-        handle_(nullptr),
-        cudnn_data_type_(CUDNN_DATA_FLOAT) {}
+  FusedBatchNormExGpuKernel() { ResetResource(); }
   ~FusedBatchNormExGpuKernel() override { DestroyResource(); }
 
   const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -142,6 +124,30 @@ class FusedBatchNormExGpuKernel : public GpuKernel {
     return true;
   }
 
+  void ResetResource() noexcept override {
+    input_x_size_ = 0;
+    input_z_size_ = 0;
+    para_size_ = 0;
+    output_size_ = 0;
+    workspace_size_ = 0;
+    reserve_size_ = 0;
+    mode_ = CUDNN_BATCHNORM_SPATIAL;
+    bn_ops_ = CUDNN_BATCHNORM_OPS_BN;
+    epsilon_ = 10e-5;
+    exp_avg_factor_ = 0.1;
+    is_null_input_ = false;
+    x_desc_ = nullptr;
+    y_desc_ = nullptr;
+    z_desc_ = nullptr;
+    scale_bias_mean_var_desc_ = nullptr;
+    activation_desc_ = nullptr;
+    handle_ = nullptr;
+    cudnn_data_type_ = CUDNN_DATA_FLOAT;
+    input_size_list_.clear();
+    output_size_list_.clear();
+    workspace_size_list_.clear();
+  }
+
   void DestroyResource() noexcept override {
     CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(x_desc_), "Destroy x desc failed");
     CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyTensorDescriptor(y_desc_), "Destroy y desc failed");
diff --git a/mindspore/core/abstract/infer_functions.h b/mindspore/core/abstract/infer_functions.h
index 92d8670e76..e830fc511b 100644
--- a/mindspore/core/abstract/infer_functions.h
+++ b/mindspore/core/abstract/infer_functions.h
@@ -51,6 +51,8 @@ AbstractBasePtr InferImplFusedSparseAdam(const AnalysisEnginePtr &, const Primit
                                          const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplFusedBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                             const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplFusedBatchNormEx(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                          const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                        const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplReluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
diff --git a/mindspore/core/abstract/prim_nn.cc b/mindspore/core/abstract/prim_nn.cc
index c5c620946a..42294df5ab 100644
--- a/mindspore/core/abstract/prim_nn.cc
+++ b/mindspore/core/abstract/prim_nn.cc
@@ -23,6 +23,15 @@
 
 namespace mindspore {
 namespace abstract {
+int64_t GetAndCheckFormat(const ValuePtr &value) {
+  int64_t data_format;
+  bool result = CheckAndConvertUtils::GetDataFormatEnumValue(value, &data_format);
+  if (!result || (data_format != Format::NHWC && data_format != Format::NCHW)) {
+    MS_LOG(EXCEPTION) << "data format is invalid, only support NCHW and NHWC";
+  }
+  return data_format;
+}
+
 AbstractBasePtr InferImplPooling(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                  const AbstractBasePtrList &args_spec_list) {
   // Inputs: a tensor.
@@ -235,6 +244,54 @@ AbstractBasePtr InferImplFusedBatchNormGrad(const AnalysisEnginePtr &, const Pri
   return std::make_shared<AbstractTuple>(rets);
 }
 
+AbstractBasePtr InferImplFusedBatchNormEx(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                          const AbstractBasePtrList &args_spec_list) {
+  // Inputs: five tensors(x, gamma, beta, mean, variance).
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 5);
+  AbstractTensorPtr input_x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  MS_EXCEPTION_IF_NULL(input_x);
+  MS_EXCEPTION_IF_NULL(input_x->shape());
+  ShapeVector x_shape = input_x->shape()->shape();
+  ShapeVector x_min_shape = input_x->shape()->min_shape();
+  ShapeVector x_max_shape = input_x->shape()->max_shape();
+  CheckMinMaxShape(x_shape, &x_min_shape, &x_max_shape);
+  if (x_shape.size() != 4) {
+    MS_LOG(EXCEPTION) << "Input rank should 4.";
+  }
+  auto data_format_ptr = primitive->GetAttr("format");
+  MS_EXCEPTION_IF_NULL(data_format_ptr);
+  int64_t data_format = GetAndCheckFormat(data_format_ptr);
+  int64_t c_axis = 1;
+  if (data_format == Format::NHWC) {
+    c_axis = 3;
+  }
+  for (size_t i = 1; i < args_spec_list.size(); ++i) {
+    AbstractTensorPtr arg_spec = CheckArg<AbstractTensor>(op_name, args_spec_list, i);
+    MS_EXCEPTION_IF_NULL(arg_spec);
+    MS_EXCEPTION_IF_NULL(arg_spec->shape());
+    ShapeVector arg_shape = arg_spec->shape()->shape();
+    if (arg_shape.size() != 1) {
+      MS_LOG(EXCEPTION) << "Arg " << i << " rank should be 1, but got " << arg_shape.size();
+    }
+    if ((x_shape[c_axis] != Shape::SHP_ANY) && (arg_shape[0] != x_shape[c_axis])) {
+      MS_LOG(EXCEPTION) << "Arg " << i << " shape[0] should equal to x_shape[" << c_axis << "]=" << x_shape[c_axis]
+                        << ", but got " << arg_shape[0];
+    }
+  }
+  AbstractTensorPtr input_gamma = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
+  ShapeVector gamma_shape = input_gamma->shape()->shape();
+  ShapeVector gamma_min_shape = input_gamma->shape()->min_shape();
+  ShapeVector gamma_max_shape = input_gamma->shape()->max_shape();
+  CheckMinMaxShape(gamma_shape, &gamma_min_shape, &gamma_max_shape);
+  ShapePtr output_shape_ptr = std::make_shared<Shape>(x_shape, x_min_shape, x_max_shape);
+  AbstractTensorPtr output = std::make_shared<AbstractTensor>(input_x->element(), output_shape_ptr);
+  ShapePtr gamma_shape_ptr = std::make_shared<Shape>(gamma_shape, gamma_min_shape, gamma_max_shape);
+  AbstractTensorPtr output_gamma = std::make_shared<AbstractTensor>(input_gamma->element(), gamma_shape_ptr);
+  AbstractBasePtrList rets = {output, output_gamma, output_gamma, output_gamma, output_gamma, output_gamma};
+  return std::make_shared<AbstractTuple>(rets);
+}
+
 AbstractBasePtr InferImplBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                        const AbstractBasePtrList &args_spec_list) {
   // Inputs: five tensors(y_backprop, x, scale, save_mean, save_inv_variance).
@@ -311,15 +368,6 @@ void Conv2DPadFunction(std::vector<int64_t> *output_hw, std::vector<int64_t> *pa
   }
 }
 
-int64_t GetAndCheckFormat(const ValuePtr &value) {
-  int64_t data_format;
-  bool result = CheckAndConvertUtils::GetDataFormatEnumValue(value, &data_format);
-  if (!result || (data_format != Format::NHWC && data_format != Format::NCHW)) {
-    MS_LOG(EXCEPTION) << "data format is invalid, only support NCHW and NHWC";
-  }
-  return data_format;
-}
-
 AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                 const AbstractBasePtrList &args_spec_list) {
   const std::string op_name = primitive->name();
diff --git a/mindspore/core/abstract/primitive_infer_map.cc b/mindspore/core/abstract/primitive_infer_map.cc
index 5086ea7d21..94f7ca10c4 100644
--- a/mindspore/core/abstract/primitive_infer_map.cc
+++ b/mindspore/core/abstract/primitive_infer_map.cc
@@ -141,6 +141,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
     {prim::kPrimFusedBatchNorm, {InferImplFusedBatchNorm, true}},
     {prim::kPrimFusedSparseAdam, {InferImplFusedSparseAdam, true}},
     {prim::kPrimFusedBatchNormGrad, {InferImplFusedBatchNormGrad, true}},
+    {prim::kPrimFusedBatchNormEx, {InferImplFusedBatchNormEx, true}},
     {prim::kPrimBatchNormGrad, {InferImplBatchNormGrad, true}},
     {prim::kPrimReluGrad, {InferImplReluGrad, true}},
     {prim::kPrimConv2D, {InferImplConv2D, true}},
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index 4cb35ddb29..0fa209b2ed 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -871,11 +871,12 @@ class FusedBatchNorm(Primitive):
         self.target = context.get_context("device_target")
 
 
-class FusedBatchNormEx(PrimitiveWithInfer):
+class FusedBatchNormEx(PrimitiveWithCheck):
     r"""
     FusedBatchNormEx is an extension of FusedBatchNorm, FusedBatchNormEx has one more output(output reserve)
     than FusedBatchNorm, reserve will be used in backpropagation phase. FusedBatchNorm is a BatchNorm that
-    moving mean and moving variance will be computed instead of being loaded.
+    moving mean and moving variance will be computed instead of being loaded. FusedBatchNormEx currently only
+    supports 4D inputs.
 
     Batch Normalization is widely used in convolutional networks. This operation applies
     Batch Normalization over input to avoid internal covariate shift as described in the
@@ -899,7 +900,7 @@ class FusedBatchNormEx(PrimitiveWithInfer):
             Default: "NCHW".
 
     Inputs:
-        - **input_x** (Tensor) - The input of FusedBatchNormEx, Tensor of shape :math:`(N, C)`,
+        - **input_x** (Tensor) - The input of FusedBatchNormEx, Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`,
           data type: float16 or float32.
         - **scale** (Parameter) - Parameter scale, same with gamma above-mentioned, Tensor of shape :math:`(C,)`,
           data type: float32.
@@ -970,25 +971,22 @@ class FusedBatchNormEx(PrimitiveWithInfer):
             raise ValueError("NHWC format only support in GPU target.")
         self.add_prim_attr('data_format', self.format)
 
-    def infer_shape(self, input_x, scale, bias, mean, variance):
+    def check_shape(self, input_x, scale, bias, mean, variance):
         input_shape_norm = input_x if self.format == "NCHW" else (input_x[0], input_x[3], input_x[1], input_x[2])
+        validator.check_equal_int(len(input_shape_norm), 4, "x rank", self.name)
         validator.check_equal_int(len(scale), 1, "scale rank", self.name)
         validator.check("scale shape", scale, "bias shape", bias, Rel.EQ, self.name)
-        validator.check("scale shape[0]", scale[0], "input channel", input_shape_norm[1], Rel.EQ, self.name)
         validator.check_equal_int(len(mean), 1, "mean rank", self.name)
 
-        validator.check("mean shape", mean, "variance shape", variance, Rel.EQ, self.name)
         validator.check("mean shape", mean, "scale shape", scale, Rel.EQ, self.name)
-        return (input_x, scale, scale, scale, scale, scale)
 
-    def infer_dtype(self, input_x, scale, bias, mean, variance):
+    def check_dtype(self, input_x, scale, bias, mean, variance):
         validator.check_tensor_dtype_valid("input_x", input_x, [mstype.float16, mstype.float32], self.name)
         args = {"scale": scale, "bias": bias}
         validator.check_tensors_dtypes_same_and_valid(args, [mstype.float32], self.name)
         args_moving = {"mean": mean, "variance": variance}
         valid_dtypes = [mstype.tensor_type(mstype.float32)]
         validator.check_types_same_and_valid(args_moving, valid_dtypes, self.name)
-        return (input_x, scale, scale, scale, scale, scale)
 
 
 class InstanceNorm(PrimitiveWithInfer):
diff --git a/tests/st/ops/gpu/test_fused_batchnorm_ex_op.py b/tests/st/ops/gpu/test_fused_batchnorm_ex_op.py
new file mode 100644
index 0000000000..62c99fc401
--- /dev/null
+++ b/tests/st/ops/gpu/test_fused_batchnorm_ex_op.py
@@ -0,0 +1,128 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+from mindspore.common.tensor import Tensor
+from mindspore.common.parameter import Parameter
+from mindspore.common.initializer import initializer
+from mindspore.nn import Cell
+from mindspore.ops.operations import _inner_ops as inner
+from mindspore.ops import operations as P
+
+
+class NetFusedBatchNormEx(Cell):
+    def __init__(self, num_features, gamma_init, beta_init, mean_init, var_init, use_batch_statistics=None):
+        super(NetFusedBatchNormEx, self).__init__()
+        self.bn = P.FusedBatchNormEx(mode=1, epsilon=0.00001, momentum=0.1)
+        self.moving_mean = Parameter(initializer(
+            mean_init, num_features), name="mean", requires_grad=False)
+        self.moving_variance = Parameter(initializer(
+            var_init, num_features), name="variance", requires_grad=False)
+        self.gamma = Parameter(initializer(
+            gamma_init, num_features), name="gamma", requires_grad=True)
+        self.beta = Parameter(initializer(
+            beta_init, num_features), name="beta", requires_grad=True)
+        self.dynshape = inner.GpuConvertToDynamicShape()
+
+    def construct(self, x):
+        x = self.bn(x, self.gamma, self.beta, self.moving_mean, self.moving_variance)
+        return x
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_fused_bn_ex():
+    x = np.array([[
+        [[1, 3, 3, 5], [2, 4, 6, 8], [3, 6, 7, 7], [4, 3, 8, 2]],
+        [[5, 7, 6, 3], [3, 5, 6, 7], [9, 4, 2, 5], [7, 5, 8, 1]]]]).astype(np.float32)
+    expect_output = np.array([[[[-0.6059, 0.3118, 0.3118, 1.2294],
+                                [-0.1471, 0.7706, 1.6882, 2.6059],
+                                [0.3118, 1.6882, 2.1471, 2.1471],
+                                [0.7706, 0.3118, 2.6059, -0.1471]],
+
+                               [[0.9119, 1.8518, 1.3819, -0.0281],
+                                [-0.0281, 0.9119, 1.3819, 1.8518],
+                                [2.7918, 0.4419, -0.4981, 0.9119],
+                                [1.8518, 0.9119, 2.3218, -0.9680]]]]).astype(np.float32)
+
+    weight = np.ones(2).astype(np.float32)
+    bias = np.ones(2).astype(np.float32)
+    moving_mean = np.ones(2).astype(np.float32)
+    moving_var = np.ones(2).astype(np.float32)
+    error = np.ones(shape=[1, 2, 4, 4]) * 1.0e-4
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    bn_net = NetFusedBatchNormEx(2, Tensor(weight), Tensor(bias), Tensor(moving_mean), Tensor(moving_var))
+    output_list = bn_net(Tensor(x))
+    output = output_list[0]
+    diff = output.asnumpy() - expect_output
+    assert np.all(diff < error)
+    assert np.all(-diff < error)
+
+
+class NetFusedBatchNormExDynamic(Cell):
+    def __init__(self, num_features, gamma_init, beta_init, mean_init, var_init, use_batch_statistics=None):
+        super(NetFusedBatchNormExDynamic, self).__init__()
+        self.bn = P.FusedBatchNormEx(mode=1, epsilon=0.00001, momentum=0.1)
+        self.moving_mean = Parameter(initializer(
+            mean_init, num_features), name="mean", requires_grad=False)
+        self.moving_variance = Parameter(initializer(
+            var_init, num_features), name="variance", requires_grad=False)
+        self.gamma = Parameter(initializer(
+            gamma_init, num_features), name="gamma", requires_grad=True)
+        self.beta = Parameter(initializer(
+            beta_init, num_features), name="beta", requires_grad=True)
+        self.dynshape = inner.GpuConvertToDynamicShape()
+
+    def construct(self, x):
+        x = self.dynshape(x)
+        x = self.bn(x, self.gamma, self.beta, self.moving_mean, self.moving_variance)
+        return x
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_fused_bn_ex_dynamic():
+    x = np.array([[
+        [[1, 3, 3, 5], [2, 4, 6, 8], [3, 6, 7, 7], [4, 3, 8, 2]],
+        [[5, 7, 6, 3], [3, 5, 6, 7], [9, 4, 2, 5], [7, 5, 8, 1]]]]).astype(np.float32)
+    expect_output = np.array([[[[-0.6059, 0.3118, 0.3118, 1.2294],
+                                [-0.1471, 0.7706, 1.6882, 2.6059],
+                                [0.3118, 1.6882, 2.1471, 2.1471],
+                                [0.7706, 0.3118, 2.6059, -0.1471]],
+
+                               [[0.9119, 1.8518, 1.3819, -0.0281],
+                                [-0.0281, 0.9119, 1.3819, 1.8518],
+                                [2.7918, 0.4419, -0.4981, 0.9119],
+                                [1.8518, 0.9119, 2.3218, -0.9680]]]]).astype(np.float32)
+
+    weight = np.ones(2).astype(np.float32)
+    bias = np.ones(2).astype(np.float32)
+    moving_mean = np.ones(2).astype(np.float32)
+    moving_var = np.ones(2).astype(np.float32)
+    error = np.ones(shape=[1, 2, 4, 4]) * 1.0e-4
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    bn_net = NetFusedBatchNormExDynamic(2, Tensor(weight), Tensor(bias), Tensor(moving_mean), Tensor(moving_var))
+    output_list = bn_net(Tensor(x))
+    output = output_list[0]
+    diff = output.asnumpy() - expect_output
+    assert np.all(diff < error)
+    assert np.all(-diff < error)
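
For quick reference, the following is a minimal usage sketch distilled from the new test above. It is illustrative only and not part of the patch; it assumes a GPU build of MindSpore from this era in which P.FusedBatchNormEx is still exposed, and the class and variable names below are placeholders of my own.

import numpy as np
import mindspore.context as context
from mindspore import Tensor, Parameter
from mindspore.common.initializer import initializer
from mindspore.nn import Cell
from mindspore.ops import operations as P


class BatchNormExNet(Cell):
    # Wraps FusedBatchNormEx with the five inputs it expects (x, scale, bias, mean, variance).
    def __init__(self, num_features):
        super(BatchNormExNet, self).__init__()
        # mode=1 selects training; momentum is the moving-average factor, as in the test.
        self.bn = P.FusedBatchNormEx(mode=1, epsilon=1e-5, momentum=0.1)
        self.moving_mean = Parameter(initializer("zeros", num_features), name="mean", requires_grad=False)
        self.moving_variance = Parameter(initializer("ones", num_features), name="variance", requires_grad=False)
        self.gamma = Parameter(initializer("ones", num_features), name="gamma", requires_grad=True)
        self.beta = Parameter(initializer("zeros", num_features), name="beta", requires_grad=True)

    def construct(self, x):
        # Returns a tuple; element 0 is the normalized output (the test also reads output_list[0]).
        return self.bn(x, self.gamma, self.beta, self.moving_mean, self.moving_variance)


# Sketch only (assumption): requires device_target="GPU"; the input must be 4-D, e.g. NCHW.
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
x = Tensor(np.random.randn(1, 2, 4, 4).astype(np.float32))
y = BatchNormExNet(2)(x)[0]
print(y.shape)  # (1, 2, 4, 4)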