|
|
|
@ -1043,6 +1043,7 @@ class Conv2D(PrimitiveWithInfer):
|
|
|
|
|
self.add_prim_attr('data_format', self.format)
|
|
|
|
|
self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name)
|
|
|
|
|
self.group = validator.check_positive_int(group, 'group', self.name)
|
|
|
|
|
self.add_prim_attr('groups', self.group)
|
|
|
|
|
self.add_prim_attr('offset_a', 0)
|
|
|
|
|
|
|
|
|
|
def infer_shape(self, x_shape, w_shape, b_shape=None):
|
|
|
|
@ -1587,6 +1588,7 @@ class Conv2DBackpropInput(PrimitiveWithInfer):
|
|
|
|
|
self.add_prim_attr('pad_mode', pad_mode)
|
|
|
|
|
self.mode = validator.check_equal_int(mode, 1, 'mode', self.name)
|
|
|
|
|
self.group = validator.check_positive_int(group, 'group', self.name)
|
|
|
|
|
self.add_prim_attr('groups', self.group)
|
|
|
|
|
self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
|
|
|
|
|
if context.get_context("device_target") != "GPU" and self.format == "NHWC":
|
|
|
|
|
raise ValueError("NHWC format only support in GPU target.")
|
|
|
|
|