|
|
|
@ -78,6 +78,17 @@ def _check_shape(arg_name, arg_value, prim_name):
|
|
|
|
|
return arg_value
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _update_attr_by_format(arg_value, arg_format):
|
|
|
|
|
"""
|
|
|
|
|
If the format is NHWC, should modify the strides or dilation shape.
|
|
|
|
|
"""
|
|
|
|
|
ret = arg_value
|
|
|
|
|
if len(arg_value) == 4 and arg_format == "NHWC":
|
|
|
|
|
ret = arg_value[1:] + (1,)
|
|
|
|
|
|
|
|
|
|
return ret
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Flatten(PrimitiveWithInfer):
|
|
|
|
|
r"""
|
|
|
|
|
Flattens a tensor without changing its batch size on the 0-th axis.
|
|
|
|
@ -2157,9 +2168,15 @@ class Conv2DBackpropInput(PrimitiveWithInfer):
|
|
|
|
|
self.init_prim_io_names(inputs=['out_backprop', 'filter', 'input_sizes'], outputs=['output'])
|
|
|
|
|
self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name)
|
|
|
|
|
self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name)
|
|
|
|
|
self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
|
|
|
|
|
if context.get_context("device_target") != "GPU" and self.format == "NHWC":
|
|
|
|
|
raise ValueError("NHWC format only support in GPU target.")
|
|
|
|
|
self.add_prim_attr('data_format', self.format)
|
|
|
|
|
self.stride = _check_positive_int_or_tuple('stride', stride, self.name, allow_four=True, ret_four=True)
|
|
|
|
|
self.stride = _update_attr_by_format(self.stride, self.format)
|
|
|
|
|
self.add_prim_attr('stride', self.stride)
|
|
|
|
|
self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True)
|
|
|
|
|
self.dilation = _update_attr_by_format(self.dilation, self.format)
|
|
|
|
|
self.add_prim_attr('dilation', self.dilation)
|
|
|
|
|
validator.check_value_type('pad', pad, (int, tuple), self.name)
|
|
|
|
|
if isinstance(pad, int):
|
|
|
|
@ -2180,10 +2197,6 @@ class Conv2DBackpropInput(PrimitiveWithInfer):
|
|
|
|
|
self.mode = validator.check_equal_int(mode, 1, 'mode', self.name)
|
|
|
|
|
self.group = validator.check_positive_int(group, 'group', self.name)
|
|
|
|
|
self.add_prim_attr('groups', self.group)
|
|
|
|
|
self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
|
|
|
|
|
if context.get_context("device_target") != "GPU" and self.format == "NHWC":
|
|
|
|
|
raise ValueError("NHWC format only support in GPU target.")
|
|
|
|
|
self.add_prim_attr('data_format', self.format)
|
|
|
|
|
if pad_list:
|
|
|
|
|
for x in pad_list:
|
|
|
|
|
validator.check_non_negative_int(x, 'element of pad_list', self.name)
|
|
|
|
|