!8093 Fix Conv2D op group attr problem.

Merge pull request !8093 from liangchenghui/fix_conv2d
pull/8093/MERGE
mindspore-ci-bot authored 4 years ago, committed by Gitee
commit 167b17299d

@@ -354,12 +354,12 @@ class AvgPool1d(_PoolNd):
kernel_size=1,
stride=1,
pad_mode="valid"):
super(AvgPool1d, self).__init__(kernel_size, stride, pad_mode)
validator.check_value_type('kernel_size', kernel_size, [int], self.cls_name)
validator.check_value_type('stride', stride, [int], self.cls_name)
self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.cls_name)
validator.check_int(kernel_size, 1, Rel.GE, "kernel_size", self.cls_name)
validator.check_int(stride, 1, Rel.GE, "stride", self.cls_name)
super(AvgPool1d, self).__init__(kernel_size, stride, pad_mode)
self.kernel_size = (1, kernel_size)
self.stride = (1, stride)
self.avg_pool = P.AvgPool(ksize=self.kernel_size,
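For reference, a minimal usage sketch of nn.AvgPool1d (shapes and sizes here are illustrative, not taken from the patch): kernel_size and stride are validated above as plain ints and then expanded to the 2-D tuples the underlying P.AvgPool kernel expects.

```python
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor

# kernel_size/stride are ints; internally stored as (1, kernel_size) and (1, stride)
pool = nn.AvgPool1d(kernel_size=3, stride=2, pad_mode="valid")
x = Tensor(np.random.rand(1, 4, 16).astype(np.float32))  # (N, C, L_in)
y = pool(x)
print(y.shape)  # (1, 4, 7): floor((16 - 3) / 2) + 1 = 7 with 'valid' padding
```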

@@ -27,7 +27,8 @@ conv2d_op_info = TBERegOp("Conv2D") \
.attr("stride", "required", "listInt", "all") \
.attr("pad_list", "required", "listInt", "all") \
.attr("dilation", "required", "listInt", "all") \
.attr("offset_a", "optional", "int", "all") \
.attr("groups", "optional", "int", "all") \
.attr("data_format", "optional", "str", "all") \
.input(0, "x", False, "required", "all") \
.input(1, "filter", False, "required", "all") \
.input(2, "bias", False, "optional", "all") \

@@ -27,7 +27,7 @@ conv2d_backprop_input_op_info = TBERegOp("Conv2DBackpropInput") \
.attr("stride", "required", "listInt", "all") \
.attr("pad_list", "required", "listInt", "all") \
.attr("dilation", "required", "listInt", "all") \
.attr("group", "optional", "int", "all") \
.attr("groups", "optional", "int", "all") \
.attr("data_format", "optional", "str", "all") \
.input(0, "out_backprop", False, "required", "all") \
.input(1, "filter", False, "required", "all") \

@@ -1043,6 +1043,7 @@ class Conv2D(PrimitiveWithInfer):
self.add_prim_attr('data_format', self.format)
self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name)
self.group = validator.check_positive_int(group, 'group', self.name)
self.add_prim_attr('groups', self.group)
self.add_prim_attr('offset_a', 0)
def infer_shape(self, x_shape, w_shape, b_shape=None):
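With the added add_prim_attr call, a grouped Conv2D primitive now exposes its group count under the "groups" key, the name the TBE op info above declares, alongside the default "offset_a". A quick sketch of inspecting those attrs:

```python
from mindspore.ops import operations as P

conv = P.Conv2D(out_channel=64, kernel_size=3, group=4)
print(conv.attrs['groups'])    # 4, registered via add_prim_attr above
print(conv.attrs['offset_a'])  # 0, default offset attr
```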
@@ -1587,6 +1588,7 @@ class Conv2DBackpropInput(PrimitiveWithInfer):
self.add_prim_attr('pad_mode', pad_mode)
self.mode = validator.check_equal_int(mode, 1, 'mode', self.name)
self.group = validator.check_positive_int(group, 'group', self.name)
self.add_prim_attr('groups', self.group)
self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
if context.get_context("device_target") != "GPU" and self.format == "NHWC":
raise ValueError("NHWC format only support in GPU target.")
