diff --git a/mindspore/nn/layer/pooling.py b/mindspore/nn/layer/pooling.py
index f8b3797832..d6b23cf386 100644
--- a/mindspore/nn/layer/pooling.py
+++ b/mindspore/nn/layer/pooling.py
@@ -354,12 +354,12 @@ class AvgPool1d(_PoolNd):
                  kernel_size=1,
                  stride=1,
                  pad_mode="valid"):
-        super(AvgPool1d, self).__init__(kernel_size, stride, pad_mode)
         validator.check_value_type('kernel_size', kernel_size, [int], self.cls_name)
         validator.check_value_type('stride', stride, [int], self.cls_name)
         self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.cls_name)
         validator.check_int(kernel_size, 1, Rel.GE, "kernel_size", self.cls_name)
         validator.check_int(stride, 1, Rel.GE, "stride", self.cls_name)
+        super(AvgPool1d, self).__init__(kernel_size, stride, pad_mode)
         self.kernel_size = (1, kernel_size)
         self.stride = (1, stride)
         self.avg_pool = P.AvgPool(ksize=self.kernel_size,
diff --git a/mindspore/ops/_op_impl/tbe/conv2d.py b/mindspore/ops/_op_impl/tbe/conv2d.py
index 5531b573c3..1773b5f110 100644
--- a/mindspore/ops/_op_impl/tbe/conv2d.py
+++ b/mindspore/ops/_op_impl/tbe/conv2d.py
@@ -27,7 +27,8 @@ conv2d_op_info = TBERegOp("Conv2D") \
     .attr("stride", "required", "listInt", "all") \
     .attr("pad_list", "required", "listInt", "all") \
     .attr("dilation", "required", "listInt", "all") \
-    .attr("offset_a", "optional", "int", "all") \
+    .attr("groups", "optional", "int", "all") \
+    .attr("data_format", "optional", "str", "all") \
     .input(0, "x", False, "required", "all") \
     .input(1, "filter", False, "required", "all") \
     .input(2, "bias", False, "optional", "all") \
diff --git a/mindspore/ops/_op_impl/tbe/conv2d_backprop_input.py b/mindspore/ops/_op_impl/tbe/conv2d_backprop_input.py
index cba8f1e980..76b5715d8e 100644
--- a/mindspore/ops/_op_impl/tbe/conv2d_backprop_input.py
+++ b/mindspore/ops/_op_impl/tbe/conv2d_backprop_input.py
@@ -27,7 +27,7 @@ conv2d_backprop_input_op_info = TBERegOp("Conv2DBackpropInput") \
     .attr("stride", "required", "listInt", "all") \
     .attr("pad_list", "required", "listInt", "all") \
     .attr("dilation", "required", "listInt", "all") \
-    .attr("group", "optional", "int", "all") \
+    .attr("groups", "optional", "int", "all") \
     .attr("data_format", "optional", "str", "all") \
     .input(0, "out_backprop", False, "required", "all") \
     .input(1, "filter", False, "required", "all") \
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index f0216de244..a50f04efff 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -1043,6 +1043,7 @@ class Conv2D(PrimitiveWithInfer):
         self.add_prim_attr('data_format', self.format)
         self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name)
         self.group = validator.check_positive_int(group, 'group', self.name)
+        self.add_prim_attr('groups', self.group)
         self.add_prim_attr('offset_a', 0)

     def infer_shape(self, x_shape, w_shape, b_shape=None):
@@ -1587,6 +1588,7 @@ class Conv2DBackpropInput(PrimitiveWithInfer):
         self.add_prim_attr('pad_mode', pad_mode)
         self.mode = validator.check_equal_int(mode, 1, 'mode', self.name)
         self.group = validator.check_positive_int(group, 'group', self.name)
+        self.add_prim_attr('groups', self.group)
         self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
         if context.get_context("device_target") != "GPU" and self.format == "NHWC":
             raise ValueError("NHWC format only support in GPU target.")
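
Reviewer note: a minimal smoke test for the two touched code paths, a sketch rather than part of this patch. Shapes and values are illustrative assumptions, and it presumes a standard MindSpore build where nn.AvgPool1d and ops.Conv2D are supported on the current device target:

    import numpy as np
    import mindspore as ms
    from mindspore import Tensor, nn
    from mindspore.ops import operations as P

    # AvgPool1d: kernel_size/stride validation now runs before the
    # base-class initializer, so bad arguments hit the specific checks first.
    pool = nn.AvgPool1d(kernel_size=3, stride=1, pad_mode="valid")
    x = Tensor(np.ones([1, 4, 8]), ms.float32)   # (batch, channel, length)
    print(pool(x).shape)                         # (1, 4, 6) under 'valid' padding

    # Conv2D: the primitive now registers attr 'groups', matching the
    # renamed attr in the TBE op info above.
    conv2d = P.Conv2D(out_channel=32, kernel_size=3, group=1)
    inp = Tensor(np.ones([1, 16, 32, 32]), ms.float32)
    w = Tensor(np.ones([32, 16, 3, 3]), ms.float32)
    print(conv2d(inp, w).shape)                  # (1, 32, 30, 30)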