From 6a58479e422e325c27ca6f68c23a38adbf510d91 Mon Sep 17 00:00:00 2001
From: TFBunny
Date: Wed, 20 Jan 2021 09:35:27 -0500
Subject: [PATCH] add dynamic shape and testcases to GPU biasadd

---
 .../gpu/{math => nn}/bias_add_gpu_kernel.cc     |  2 +-
 .../gpu/{math => nn}/bias_add_gpu_kernel.h      | 27 +++++---
 mindspore/core/abstract/infer_functions.h       |  2 +
 mindspore/core/abstract/prim_nn.cc              | 35 ++++++++++
 .../core/abstract/primitive_infer_map.cc        |  1 +
 mindspore/ops/operations/nn_ops.py              | 21 +++---
 tests/st/ops/gpu/test_dense_op.py               | 65 ++++++++++++++++++-
 7 files changed, 131 insertions(+), 22 deletions(-)
 rename mindspore/ccsrc/backend/kernel_compiler/gpu/{math => nn}/bias_add_gpu_kernel.cc (94%)
 rename mindspore/ccsrc/backend/kernel_compiler/gpu/{math => nn}/bias_add_gpu_kernel.h (91%)

diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/bias_add_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/bias_add_gpu_kernel.cc
similarity index 94%
rename from mindspore/ccsrc/backend/kernel_compiler/gpu/math/bias_add_gpu_kernel.cc
rename to mindspore/ccsrc/backend/kernel_compiler/gpu/nn/bias_add_gpu_kernel.cc
index a07fb6ddf6..242cfa1e96 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/bias_add_gpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/bias_add_gpu_kernel.cc
@@ -14,7 +14,7 @@
  * limitations under the License.
  */
 
-#include "backend/kernel_compiler/gpu/math/bias_add_gpu_kernel.h"
+#include "backend/kernel_compiler/gpu/nn/bias_add_gpu_kernel.h"
 
 namespace mindspore {
 namespace kernel {
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/bias_add_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/bias_add_gpu_kernel.h
similarity index 91%
rename from mindspore/ccsrc/backend/kernel_compiler/gpu/math/bias_add_gpu_kernel.h
rename to mindspore/ccsrc/backend/kernel_compiler/gpu/nn/bias_add_gpu_kernel.h
index 00d7be7037..a9b6534517 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/bias_add_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/nn/bias_add_gpu_kernel.h
@@ -14,8 +14,8 @@
  * limitations under the License.
  */
 
-#ifndef MINDSPORE_BIAS_ADD_GPU_KERNEL_H
-#define MINDSPORE_BIAS_ADD_GPU_KERNEL_H
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_BIAS_ADD_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_BIAS_ADD_GPU_KERNEL_H_
 #include
 #include
 #include
@@ -30,13 +30,7 @@ namespace kernel {
 template <typename T>
 class BiasAddGpuKernel : public GpuKernel {
  public:
-  BiasAddGpuKernel()
-      : cudnn_handle_(nullptr),
-        cudnn_data_type_(CUDNN_DATA_FLOAT),
-        x_desc_(nullptr),
-        b_desc_(nullptr),
-        op_desc_(nullptr),
-        is_null_input_(false) {}
+  BiasAddGpuKernel() { ResetResource(); }
   ~BiasAddGpuKernel() override { DestroyResource(); }
 
   const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -117,6 +111,18 @@ class BiasAddGpuKernel : public GpuKernel {
     return true;
   }
 
+  void ResetResource() noexcept override {
+    cudnn_handle_ = nullptr;
+    cudnn_data_type_ = CUDNN_DATA_FLOAT;
+    x_desc_ = nullptr;
+    b_desc_ = nullptr;
+    op_desc_ = nullptr;
+    is_null_input_ = false;
+    input_size_list_.clear();
+    output_size_list_.clear();
+    workspace_size_list_.clear();
+  }
+
   void DestroyResource() noexcept override {
     CHECK_CUDNN_RET_WITH_ERROR(kernel_node_, cudnnDestroyOpTensorDescriptor(op_desc_),
                                "cudnnDestroyTensorDescriptor failed");
@@ -136,6 +142,7 @@ class BiasAddGpuKernel : public GpuKernel {
     CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnCreateOpTensorDescriptor(&op_desc_),
                                 "cudnnCreateOpTensorDescriptor failed");
   }
+
   void InitSizeLists() override {
     size_t x_size, b_size;
     CHECK_CUDNN_RET_WITH_EXCEPT(kernel_node_, cudnnGetTensorSizeInBytes(x_desc_, &x_size),
@@ -161,4 +168,4 @@ class BiasAddGpuKernel : public GpuKernel {
 }  // namespace kernel
 }  // namespace mindspore
 
-#endif  // MINDSPORE_BIAS_ADD_GPU_KERNEL_H
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_BIAS_ADD_GPU_KERNEL_H_
diff --git a/mindspore/core/abstract/infer_functions.h b/mindspore/core/abstract/infer_functions.h
index 25abf5e75d..95e451acb3 100644
--- a/mindspore/core/abstract/infer_functions.h
+++ b/mindspore/core/abstract/infer_functions.h
@@ -63,6 +63,8 @@ AbstractBasePtr InferImplConv2DBackpropInput(const AnalysisEnginePtr &, const Pr
                                              const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplConv2DBackpropFilter(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                               const AbstractBasePtrList &args_spec_list);
+AbstractBasePtr InferImplBiasAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                 const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplBiasAddGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                      const AbstractBasePtrList &args_spec_list);
 AbstractBasePtr InferImplGelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
diff --git a/mindspore/core/abstract/prim_nn.cc b/mindspore/core/abstract/prim_nn.cc
index 4aad2b8f3a..2e31a30983 100644
--- a/mindspore/core/abstract/prim_nn.cc
+++ b/mindspore/core/abstract/prim_nn.cc
@@ -470,6 +470,41 @@ AbstractBasePtr InferImplConv2DBackpropFilter(const AnalysisEnginePtr &, const P
   return args_spec_list[2]->Broaden();
 }
 
+AbstractBasePtr InferImplBiasAdd(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                 const AbstractBasePtrList &args_spec_list) {
+  const std::string op_name = primitive->name();
+  CheckArgsSize(op_name, args_spec_list, 2);
+  auto x = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  auto bias = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
+  MS_EXCEPTION_IF_NULL(x);
+  MS_EXCEPTION_IF_NULL(x->shape());
+  ShapeVector x_shape = x->shape()->shape();
+  MS_EXCEPTION_IF_NULL(bias);
+  MS_EXCEPTION_IF_NULL(bias->shape());
+  ShapeVector bias_shape = bias->shape()->shape();
+  ShapeVector x_min_shape = x->shape()->min_shape();
+  ShapeVector x_max_shape = x->shape()->max_shape();
+  std::set<std::string> available_data_format{"NCHW", "NHWC"};
+  auto data_format_ptr = primitive->GetAttr("data_format");
+  std::string data_format = "NCHW";
+  if ((data_format_ptr != nullptr) && data_format_ptr->isa<StringImm>()) {
+    data_format = data_format_ptr->cast<StringImmPtr>()->value();
+  }
+  if (available_data_format.find(data_format) == available_data_format.end()) {
+    MS_LOG(EXCEPTION) << "Unsupported data format: " << data_format << ", use NCHW or NHWC.";
+  }
+  auto x_channel = data_format == "NHWC" ? x_shape[x_shape.size() - 1] : x_shape[1];
+  // Only validate the channel match once the shape is fully known;
+  // the last inference pass runs with the real shape values.
+  bool x_not_dyn = std::all_of(x_shape.begin(), x_shape.end(),
+                               [](int64_t value) { return value != Shape::SHP_ANY; });
+  if (x_not_dyn && bias_shape[0] != x_channel) {
+    MS_LOG(EXCEPTION) << "BiasAdd shape error, data format is " << data_format
+                      << ", got bias_shape[0]: " << bias_shape[0] << ", x_channel: " << x_channel << ".";
+  }
+  (void)CheckMinMaxShape(x_shape, &x_min_shape, &x_max_shape);
+  return std::make_shared<AbstractTensor>(x->element(),
+                                          std::make_shared<Shape>(x_shape, x_min_shape, x_max_shape));
+}
+
 AbstractBasePtr InferImplBiasAddGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                      const AbstractBasePtrList &args_spec_list) {
   // Inputs: at least one tensor(y_backprop)
diff --git a/mindspore/core/abstract/primitive_infer_map.cc b/mindspore/core/abstract/primitive_infer_map.cc
index 0af8578033..cab15ef68f 100644
--- a/mindspore/core/abstract/primitive_infer_map.cc
+++ b/mindspore/core/abstract/primitive_infer_map.cc
@@ -114,6 +114,7 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
     {prim::kPrimConv2D, {InferImplConv2D, true}},
     {prim::kPrimConv2DBackpropInput, {InferImplConv2DBackpropInput, true}},
     {prim::kPrimConv2DBackpropFilter, {InferImplConv2DBackpropFilter, true}},
+    {prim::kPrimBiasAdd, {InferImplBiasAdd, true}},
     {prim::kPrimBiasAddGrad, {InferImplBiasAddGrad, true}},
     {prim::kPrimRelu, {InferImplRelu, true}},
     {prim::kPrimZerosLike, {InferImplZerosLike, true}},
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index 5f524aa79b..1aafdd0af3 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -1887,19 +1887,21 @@ class Conv2DBackpropInput(PrimitiveWithInfer):
         return out
 
 
-class BiasAdd(PrimitiveWithInfer):
+class BiasAdd(PrimitiveWithCheck):
     r"""
     Returns the sum of the input tensor and the bias tensor.
 
     Adds the 1-D bias tensor to the input tensor, and broadcasts the shape on all axes
     except for the channel axis.
 
+    Args:
+        data_format (str): The format of input and output data. It should be 'NHWC' or 'NCHW'.
+            Default: 'NCHW'.
+
     Inputs:
         - **input_x** (Tensor) - The input tensor. The shape can be 2-4 dimensions.
-        - **bias** (Tensor) - The bias tensor, with shape :math:`(C)`.
-        - **data_format** (str) - The format of input and output data. It should be 'NHWC' or 'NCHW',\
-          default is 'NCHW'.
-          The shape of `bias` must be the same as `input_x` in the second dimension.
+        - **bias** (Tensor) - The bias tensor, with shape :math:`(C)`. The shape of
+          `bias` must match the channel dimension of `input_x`.
 
     Outputs:
         Tensor, with the same shape and type as `input_x`.
@@ -1924,17 +1926,16 @@ class BiasAdd(PrimitiveWithInfer):
             raise ValueError("NHWC format only support in GPU target.")
         self.add_prim_attr('data_format', self.format)
 
-    def infer_shape(self, x_shape, b_shape):
+    def check_shape(self, x_shape, b_shape):
         validator.check_int(len(x_shape), 2, Rel.GE, "x rank", self.name)
         validator.check_equal_int(len(b_shape), 1, "bias rank", self.name)
         x_channel = x_shape[1] if self.format == "NCHW" else x_shape[-1]
-        validator.check("b_shape[0]", b_shape[0], "x_shape[1]", x_channel, Rel.EQ, self.name)
-        return x_shape
+        if np.all(np.array(x_shape) != -1):
+            validator.check("b_shape[0]", b_shape[0], "x_channel", x_channel, Rel.EQ, self.name)
 
-    def infer_dtype(self, x_type, b_type):
+    def check_dtype(self, x_type, b_type):
         args = {"input_x": x_type, "bias": b_type}
         validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
-        return x_type
 
 
 class TopK(PrimitiveWithInfer):
diff --git a/tests/st/ops/gpu/test_dense_op.py b/tests/st/ops/gpu/test_dense_op.py
index 4c03bfc417..a8cba4c1db 100644
--- a/tests/st/ops/gpu/test_dense_op.py
+++ b/tests/st/ops/gpu/test_dense_op.py
@@ -23,7 +23,7 @@ from mindspore.common.parameter import ParameterTuple
 from mindspore.ops import composite as C
 from mindspore.ops import operations as P
 from mindspore.ops.composite import GradOperation
-
+from mindspore.ops.operations import _inner_ops as inner
 
 class BiasAdd(nn.Cell):
     def __init__(self):
@@ -442,3 +442,66 @@ def test_biasadd_4d():
     error = np.ones(shape=[3]) * 1.0e-6
     assert np.all(diff < error)
     assert np.all(-diff < error)
+
+
+class BiasAddDynamic(nn.Cell):
+    def __init__(self):
+        super(BiasAddDynamic, self).__init__()
+        self.ba = P.BiasAdd()
+        self.test_dynamic = inner.GpuConvertToDynamicShape()
+
+    def construct(self, x, b):
+        x = self.test_dynamic(x)
+        output = self.ba(x, b)
+        return output
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_bias_add_dynamic_two_inputs():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    net = BiasAddDynamic()
+
+    x_1 = Tensor(np.array([[0.1, 0.2, 0.3, 0.4],
+                           [0.5, 0.6, 0.7, 0.8],
+                           [0.9, 1.0, 1.1, 1.2]]).astype(np.float32))
+    b_1 = Tensor(np.array([0.1, 0.2, 0.3, 0.4]).astype(np.float32))
+    expect_1 = np.array([[0.2, 0.4, 0.6, 0.8],
+                         [0.6, 0.8, 1.0, 1.2],
+                         [1.0, 1.2, 1.4, 1.6]])
+    error_1 = np.ones(shape=[3, 4]) * 1.0e-6
+    result_1 = net(x_1, b_1)
+    diff_1 = result_1.asnumpy() - expect_1
+    assert np.all(diff_1 < error_1)
+    assert np.all(-diff_1 < error_1)
+
+    x_2 = Tensor(np.array([[[1, 2, 3, 4, 5, 6, 7, 8],
+                            [9, 10, 11, 12, 13, 14, 15, 16],
+                            [17, 18, 19, 20, 21, 22, 23, 24],
+                            [25, 26, 27, 28, 29, 30, 31, 32]],
+                           [[33, 34, 35, 36, 37, 38, 39, 40],
+                            [41, 42, 43, 44, 45, 46, 47, 48],
+                            [49, 50, 51, 52, 53, 54, 55, 56],
+                            [57, 58, 59, 60, 61, 62, 63, 64]],
+                           [[65, 66, 67, 68, 69, 70, 71, 72],
+                            [73, 74, 75, 76, 77, 78, 79, 80],
+                            [81, 82, 83, 84, 85, 86, 87, 88],
+                            [89, 90, 91, 92, 93, 94, 95, 96]]]).astype(np.float32))
+    b_2 = Tensor(np.array([1, 2, 3, 4]).astype(np.float32))
+    expect_2 = np.array([[[2, 3, 4, 5, 6, 7, 8, 9],
+                          [11, 12, 13, 14, 15, 16, 17, 18],
+                          [20, 21, 22, 23, 24, 25, 26, 27],
+                          [29, 30, 31, 32, 33, 34, 35, 36]],
+                         [[34, 35, 36, 37, 38, 39, 40, 41],
+                          [43, 44, 45, 46, 47, 48, 49, 50],
+                          [52, 53, 54, 55, 56, 57, 58, 59],
+                          [61, 62, 63, 64, 65, 66, 67, 68]],
+                         [[66, 67, 68, 69, 70, 71, 72, 73],
+                          [75, 76, 77, 78, 79, 80, 81, 82],
+                          [84, 85, 86, 87, 88, 89, 90, 91],
+                          [93, 94, 95, 96, 97, 98, 99, 100]]])
+    error_2 = np.ones(shape=[3, 4, 8]) * 1.0e-6
+    result_2 = net(x_2, b_2)
+    diff_2 = result_2.asnumpy() - expect_2
+    assert np.all(diff_2 < error_2)
+    assert np.all(-diff_2 < error_2)
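
Usage note (not part of the patch): the sketch below shows how the dynamic-shape path added here can be exercised outside the test suite. It assumes a GPU build of MindSpore with this patch applied, and it reuses the inner GpuConvertToDynamicShape debug op from the tests above to erase the static shape; the DynamicBiasAdd cell name is illustrative.

    import numpy as np
    import mindspore.context as context
    import mindspore.nn as nn
    from mindspore import Tensor
    from mindspore.ops import operations as P
    from mindspore.ops.operations import _inner_ops as inner

    class DynamicBiasAdd(nn.Cell):
        """Illustrative cell: feeds BiasAdd an input whose shape is only known at runtime."""
        def __init__(self, data_format="NCHW"):
            super(DynamicBiasAdd, self).__init__()
            self.bias_add = P.BiasAdd(data_format=data_format)
            self.to_dynamic = inner.GpuConvertToDynamicShape()

        def construct(self, x, b):
            # Erasing the static shape sends BiasAdd down the new dynamic branch:
            # check_shape skips the channel check for -1 dims, and InferImplBiasAdd
            # propagates min/max shapes until the real shape is known.
            return self.bias_add(self.to_dynamic(x), b)

    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    x = Tensor(np.arange(12).reshape(3, 4).astype(np.float32))
    b = Tensor(np.ones(4).astype(np.float32))
    print(DynamicBiasAdd()(x, b))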