From bd13f9ba339e2cdffdaf9a2d192a6aff5ca92e85 Mon Sep 17 00:00:00 2001
From: chang zherui <760161589@qq.com>
Date: Thu, 23 Apr 2020 15:17:06 +0800
Subject: [PATCH] modify ResizeNearestNeighborV2D

---
 .../gpu/nn/fused_batch_norm_gpu_kernel.cc |  2 -
 .../gpu/nn/fused_batch_norm_gpu_kernel.h  |  3 -
 mindspore/ccsrc/transform/op_declare.cc   | 59 +++++++++----------
 mindspore/ops/_grad/grad_nn_ops.py        |  4 +-
 mindspore/ops/operations/_grad_ops.py     |  4 +-
 mindspore/ops/operations/nn_ops.py        |  8 +--
 tests/ut/python/ops/test_ops.py           |  2 +-
 7 files changed, 35 insertions(+), 47 deletions(-)

diff --git a/mindspore/ccsrc/kernel/gpu/nn/fused_batch_norm_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/fused_batch_norm_gpu_kernel.cc
index 4ddc710a4c..91747d24d8 100644
--- a/mindspore/ccsrc/kernel/gpu/nn/fused_batch_norm_gpu_kernel.cc
+++ b/mindspore/ccsrc/kernel/gpu/nn/fused_batch_norm_gpu_kernel.cc
@@ -55,7 +55,6 @@ MS_REG_GPU_KERNEL_ONE(BatchNorm,
                         .AddOutputAttr(kNumberTypeFloat32)
                         .AddOutputAttr(kNumberTypeFloat32)
                         .AddOutputAttr(kNumberTypeFloat32)
-                        .AddOutputAttr(kNumberTypeFloat32)
                         .AddOutputAttr(kNumberTypeFloat32),
                       FusedBatchNormGpuKernel, float)
 MS_REG_GPU_KERNEL_ONE(BatchNorm,
@@ -69,7 +68,6 @@ MS_REG_GPU_KERNEL_ONE(BatchNorm,
                         .AddOutputAttr(kNumberTypeFloat16)
                         .AddOutputAttr(kNumberTypeFloat16)
                         .AddOutputAttr(kNumberTypeFloat16)
-                        .AddOutputAttr(kNumberTypeFloat16)
                         .AddOutputAttr(kNumberTypeFloat16),
                       FusedBatchNormGpuKernel, half)
 }  // namespace kernel
diff --git a/mindspore/ccsrc/kernel/gpu/nn/fused_batch_norm_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/fused_batch_norm_gpu_kernel.h
index 6f0c59e29a..5ca85f8e63 100644
--- a/mindspore/ccsrc/kernel/gpu/nn/fused_batch_norm_gpu_kernel.h
+++ b/mindspore/ccsrc/kernel/gpu/nn/fused_batch_norm_gpu_kernel.h
@@ -156,9 +156,6 @@ class FusedBatchNormGpuKernel : public GpuKernel {
     output_size_list_.push_back(para_size);  // running variance
     output_size_list_.push_back(para_size);  // save mean
     output_size_list_.push_back(para_size);  // save variance
-    if (!is_train_) {
-      output_size_list_.push_back(para_size);  // reserve
-    }
     return;
   }

diff --git a/mindspore/ccsrc/transform/op_declare.cc b/mindspore/ccsrc/transform/op_declare.cc
index d6ca3f4cbe..299ac4f44d 100644
--- a/mindspore/ccsrc/transform/op_declare.cc
+++ b/mindspore/ccsrc/transform/op_declare.cc
@@ -154,14 +154,14 @@ ATTR_MAP(BatchNorm) = {{"data_format", ATTR_DESC(data_format, AnyTraits<std::string>())},
                        {"epsilon", ATTR_DESC(epsilon, AnyTraits<float>())},
                        {"is_training", ATTR_DESC(is_training, AnyTraits<bool>())}};
@@ -266,11 +266,6 @@ INPUT_MAP(GatherV2) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(indices)}, {3, INPUT_D
 ATTR_MAP(GatherV2) = EMPTY_ATTR_MAP;
 OUTPUT_MAP(GatherV2) = {{0, OUTPUT_DESC(y)}};
 
-// ReduceSum
-INPUT_MAP(ReduceSum) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(axes)}};
-ATTR_MAP(ReduceSum) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits<bool>())}};
-OUTPUT_MAP(ReduceSum) = {{0, OUTPUT_DESC(y)}};
-
 // ReduceSumD
 INPUT_MAP(ReduceSumD) = {{1, INPUT_DESC(x)}};
 INPUT_ATTR_MAP(ReduceSumD) = {
@@ -451,17 +446,17 @@ INPUT_MAP(Iou) = {{1, INPUT_DESC(bboxes)}, {2, INPUT_DESC(gtboxes)}};
 ATTR_MAP(Iou) = {{"mode", ATTR_DESC(mode, AnyTraits<std::string>())}};
 OUTPUT_MAP(Iou) = {{0, OUTPUT_DESC(overlap)}};
 
-// ResizeNearestNeighborD
-INPUT_MAP(ResizeNearestNeighborD) = {{1, INPUT_DESC(x)}};
-ATTR_MAP(ResizeNearestNeighborD) = {
+// ResizeNearestNeighborV2D
+INPUT_MAP(ResizeNearestNeighborV2D) = {{1, INPUT_DESC(x)}};
+ATTR_MAP(ResizeNearestNeighborV2D) = {
   {"size", ATTR_DESC(size, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
   {"align_corners", ATTR_DESC(align_corners, AnyTraits<bool>())}};
-OUTPUT_MAP(ResizeNearestNeighborD) = {{0, OUTPUT_DESC(y)}};
+OUTPUT_MAP(ResizeNearestNeighborV2D) = {{0, OUTPUT_DESC(y)}};
 
-// ResizeNearestNeighborGrad
-INPUT_MAP(ResizeNearestNeighborGrad) = {{1, INPUT_DESC(grads)}, {2, INPUT_DESC(size)}};
-ATTR_MAP(ResizeNearestNeighborGrad) = {{"align_corners", ATTR_DESC(align_corners, AnyTraits<bool>())}};
-OUTPUT_MAP(ResizeNearestNeighborGrad) = {{0, OUTPUT_DESC(y)}};
+// ResizeNearestNeighborV2Grad
+INPUT_MAP(ResizeNearestNeighborV2Grad) = {{1, INPUT_DESC(grads)}, {2, INPUT_DESC(size)}};
+ATTR_MAP(ResizeNearestNeighborV2Grad) = {{"align_corners", ATTR_DESC(align_corners, AnyTraits<bool>())}};
+OUTPUT_MAP(ResizeNearestNeighborV2Grad) = {{0, OUTPUT_DESC(y)}};
 
 // ApplyAdam
 INPUT_MAP(ApplyAdam) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(m)}, {3, INPUT_DESC(v)},
@@ -486,17 +481,17 @@ INPUT_MAP(Relu6Grad) = {{1, INPUT_DESC(gradients)}, {2, INPUT_DESC(features)}};
 ATTR_MAP(Relu6Grad) = EMPTY_ATTR_MAP;
 OUTPUT_MAP(Relu6Grad) = {{0, OUTPUT_DESC(backprops)}};
 
-// ResizeBilinearGrad
-INPUT_MAP(ResizeBilinearGrad) = {{1, INPUT_DESC(grads)}, {2, INPUT_DESC(original_image)}};
-ATTR_MAP(ResizeBilinearGrad) = {{"align_corners", ATTR_DESC(align_corners, AnyTraits<bool>())}};
-OUTPUT_MAP(ResizeBilinearGrad) = {{0, OUTPUT_DESC(y)}};
+// ResizeBilinearV2Grad
+INPUT_MAP(ResizeBilinearV2Grad) = {{1, INPUT_DESC(grads)}, {2, INPUT_DESC(original_image)}};
+ATTR_MAP(ResizeBilinearV2Grad) = {{"align_corners", ATTR_DESC(align_corners, AnyTraits<bool>())}};
+OUTPUT_MAP(ResizeBilinearV2Grad) = {{0, OUTPUT_DESC(y)}};
 
-// ResizeBilinearD
-INPUT_MAP(ResizeBilinearD) = {{1, INPUT_DESC(x)}};
-ATTR_MAP(ResizeBilinearD) = {
+// ResizeBilinearV2D
+INPUT_MAP(ResizeBilinearV2D) = {{1, INPUT_DESC(x)}};
+ATTR_MAP(ResizeBilinearV2D) = {
   {"size", ATTR_DESC(size, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())},
   {"align_corners", ATTR_DESC(align_corners, AnyTraits<bool>())}};
-OUTPUT_MAP(ResizeBilinearD) = {{0, OUTPUT_DESC(y)}};
+OUTPUT_MAP(ResizeBilinearV2D) = {{0, OUTPUT_DESC(y)}};
 
 // ZerosLike
 INPUT_MAP(ZerosLike) = {{1, INPUT_DESC(x)}};
@@ -609,10 +604,12 @@ ATTR_MAP(ArgMinWithValue) = {{"axis", ATTR_DESC(dimension, AnyTraits<int64_t>())},
                              {"keep_dims", ATTR_DESC(keep_dims, AnyTraits<bool>())}};
 OUTPUT_MAP(ArgMinWithValue) = {{0, OUTPUT_DESC(indice)}, {1, OUTPUT_DESC(values)}};
 
-// ReduceAll
-INPUT_MAP(ReduceAll) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(axes)}};
-ATTR_MAP(ReduceAll) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits<bool>())}};
-OUTPUT_MAP(ReduceAll) = {{0, OUTPUT_DESC(y)}}
+// ReduceAllD
+INPUT_MAP(ReduceAllD) = {{1, INPUT_DESC(x)}};
+INPUT_ATTR_MAP(ReduceAllD) = {
+  {2, ATTR_DESC(axis, AnyTraits<std::vector<int64_t>>(), AnyTraits<std::vector<int64_t>>())}};
+ATTR_MAP(ReduceAllD) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits<bool>())}};
+OUTPUT_MAP(ReduceAllD) = {{0, OUTPUT_DESC(y)}};
 
 // ReduceMeanD
 INPUT_MAP(ReduceMeanD) = {{1, INPUT_DESC(x)}};
diff --git a/mindspore/ops/_grad/grad_nn_ops.py b/mindspore/ops/_grad/grad_nn_ops.py
index e43d3d5d3a..6db059a7bb 100755
--- a/mindspore/ops/_grad/grad_nn_ops.py
+++ b/mindspore/ops/_grad/grad_nn_ops.py
@@ -356,12 +356,10 @@ def get_bprop_batch_norm(self):
         if is_training:
             saved_reserve_1 = out[3]
             saved_reserve_2 = out[4]
-            saved_reserve_3 = out[5]
         else:
             saved_reserve_1 = mean
             saved_reserve_2 = variance
-            saved_reserve_3 = variance
-        out = input_grad(dout[0], x, scale, saved_reserve_1, saved_reserve_2, saved_reserve_3)
+        out = input_grad(dout[0], x, scale, saved_reserve_1, saved_reserve_2)
         dx = out[0]
         dscale = out[1]
         dbias = out[2]
diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py
index c29832dcb7..9f277908ed 100644
--- a/mindspore/ops/operations/_grad_ops.py
+++ b/mindspore/ops/operations/_grad_ops.py
@@ -69,11 +69,11 @@ class BatchNormGrad(PrimitiveWithInfer):
         self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT)
         self.add_prim_attr('data_format', "NCHW")
 
-    def infer_shape(self, y_backprop_shape, x_shape, scale_shape, reserve_1_shape, reserve_2_shape, reserve_3_shape):
+    def infer_shape(self, y_backprop_shape, x_shape, scale_shape, reserve_1_shape, reserve_2_shape):
         validator.check("BatchNorm y_backprop_shape", y_backprop_shape, "BatchNorm x_shape", x_shape)
         return (x_shape, scale_shape, scale_shape, reserve_1_shape, reserve_2_shape)
 
-    def infer_dtype(self, y_backprop_type, x_type, scale_type, reserve_1_type, reserve_2_type, reserve_3_type):
+    def infer_dtype(self, y_backprop_type, x_type, scale_type, reserve_1_type, reserve_2_type):
         return (x_type, scale_type, scale_type, reserve_1_type, reserve_2_type)
 
 
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index 49145fb072..93359c7dd9 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -537,7 +537,6 @@ class BatchNorm(PrimitiveWithInfer):
         - **updated_bias** (Tensor) - Tensor of shape :math:`(C,)`.
         - **reserve_space_1** (Tensor) - Tensor of shape :math:`(C,)`.
         - **reserve_space_2** (Tensor) - Tensor of shape :math:`(C,)`.
-        - **reserve_space_3** (Tensor) - Tensor of shape :math:`(C,)`.
     """
 
     @prim_attr_register
@@ -546,8 +545,7 @@ class BatchNorm(PrimitiveWithInfer):
         validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name)
         self.add_prim_attr('data_format', "NCHW")
         self.init_prim_io_names(inputs=['x', 'scale', 'offset', 'mean', 'variance'],
-                                outputs=['y', 'batch_mean', 'batch_variance', 'reserve_space_1', 'reserve_space_2',
-                                         'reserve_space_3'])
+                                outputs=['y', 'batch_mean', 'batch_variance', 'reserve_space_1', 'reserve_space_2'])
 
     def infer_shape(self, input_x, scale, bias, mean, variance):
         validator.check_integer("scale rank", len(scale), 1, Rel.EQ, self.name)
@@ -557,7 +555,7 @@ class BatchNorm(PrimitiveWithInfer):
         validator.check_integer("mean rank", len(mean), 1, Rel.EQ, self.name)
         validator.check("mean shape", mean, "variance shape", variance, Rel.EQ, self.name)
         validator.check("mean shape", mean, "scale shape", scale, Rel.EQ, self.name)
-        return (input_x, scale, scale, scale, scale, scale)
+        return (input_x, scale, scale, scale, scale)
 
     def infer_dtype(self, input_x, scale, bias, mean, variance):
         validator.check_tensor_type_same({"input_x": input_x}, [mstype.float16, mstype.float32], self.name)
@@ -570,7 +568,7 @@ class BatchNorm(PrimitiveWithInfer):
         else:
             args_moving = {"mean": mean, "variance": variance}
             validator.check_tensor_type_same(args_moving, [mstype.float16, mstype.float32], self.name)
-        return (input_x, scale, bias, input_x, input_x, input_x)
+        return (input_x, scale, bias, input_x, input_x)
 
 
 class Conv2D(PrimitiveWithInfer):
diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py
index 8b14ea2366..1dea7b6502 100755
--- a/tests/ut/python/ops/test_ops.py
+++ b/tests/ut/python/ops/test_ops.py
@@ -671,7 +671,7 @@ test_case_nn_ops = [
         'skip': []}),
     ('BatchNormGrad', {
         'block': G.BatchNormGrad(),
-        'desc_inputs': [[128, 64, 32, 32], [128, 64, 32, 32], [64], [64], [64], [64]],
+        'desc_inputs': [[128, 64, 32, 32], [128, 64, 32, 32], [64], [64], [64]],
         'desc_bprop': [[128, 64, 32, 32], [64], [64], [64], [64]],
         'skip': ['backward']}),
     ('ApplyMomentum', {
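
Note for reviewers: the five-output BatchNorm contract this patch settles on can be exercised directly from Python. The snippet below is a minimal sketch, not part of the patch; it assumes a MindSpore build that already contains this change, and the channel count, tensor shapes, and PyNative-mode setup are illustrative only.

    import numpy as np
    import mindspore as ms
    from mindspore import context
    from mindspore.ops import operations as P

    # Primitives can be invoked eagerly in PyNative mode (illustrative setup).
    context.set_context(mode=context.PYNATIVE_MODE)

    c = 64  # channel count; all shapes here are made up for the example
    bn = P.BatchNorm(is_training=True, epsilon=1e-5)

    x = ms.Tensor(np.random.randn(8, c, 32, 32).astype(np.float32))
    scale = ms.Tensor(np.ones(c).astype(np.float32))
    offset = ms.Tensor(np.zeros(c).astype(np.float32))
    mean = ms.Tensor(np.zeros(c).astype(np.float32))
    variance = ms.Tensor(np.ones(c).astype(np.float32))

    # After this patch BatchNorm yields exactly five outputs; reserve_space_3
    # is gone from the op, its bprop, and the GPU kernel registration alike.
    y, batch_mean, batch_var, reserve_1, reserve_2 = bn(x, scale, offset, mean, variance)

Dropping the unused reserve_space_3 output keeps the GPU kernel registration, the GE BatchNorm declaration, BatchNormGrad's five-tensor signature, and the test case's desc_inputs mutually consistent, which is why all seven files change together.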