From 7f9bbfd338bacfa2988c9c13d7ff49d4a5c16ec2 Mon Sep 17 00:00:00 2001 From: jiangjinsheng Date: Tue, 7 Jul 2020 16:03:33 +0800 Subject: [PATCH] add Conv1d ops --- mindspore/nn/layer/conv.py | 346 +++++++++++++++++- mindspore/ops/operations/nn_ops.py | 46 ++- .../gtest_input/pynative/ops_test.py | 32 +- tests/vm_impl/vm_me.py | 58 ++- 4 files changed, 455 insertions(+), 27 deletions(-) diff --git a/mindspore/nn/layer/conv.py b/mindspore/nn/layer/conv.py index 52ec9f2d63..05570ea19b 100644 --- a/mindspore/nn/layer/conv.py +++ b/mindspore/nn/layer/conv.py @@ -18,11 +18,12 @@ from mindspore.ops import operations as P from mindspore.common.parameter import Parameter from mindspore.common.initializer import initializer from mindspore._checkparam import ParamValidator as validator, Rel +from mindspore._checkparam import Validator from mindspore._checkparam import check_bool, twice, check_int_positive, check_int_non_negative from mindspore._extends import cell_attr_register from ..cell import Cell -__all__ = ['Conv2d', 'Conv2dTranspose', 'DepthwiseConv2d'] +__all__ = ['Conv2d', 'Conv2dTranspose', 'DepthwiseConv2d', 'Conv1d', 'Conv1dTranspose'] class _Conv(Cell): """ @@ -241,6 +242,174 @@ class Conv2d(_Conv): return s +class Conv1d(_Conv): + r""" + 1D convolution layer. + + Applies a 1D convolution over an input tensor which is typically of shape :math:`(N, C_{in}, W_{in})`, + where :math:`N` is batch size and :math:`C_{in}` is channel number. For each batch of shape + :math:`(C_{in}, W_{in})`, the formula is defined as: + + .. math:: + + out_j = \sum_{i=0}^{C_{in} - 1} ccor(W_{ij}, X_i) + b_j, + + where :math:`ccor` is cross correlation operator, :math:`C_{in}` is the input channel number, :math:`j` ranges + from :math:`0` to :math:`C_{out} - 1`, :math:`W_{ij}` corresponds to :math:`i`-th channel of the :math:`j`-th + filter and :math:`out_{j}` corresponds to the :math:`j`-th channel of the output. :math:`W_{ij}` is a slice + of kernel and it has shape :math:`(\text{ks_w})`, where :math:`\text{ks_w}` are width of the convolution kernel. + The full kernel has shape :math:`(C_{out}, C_{in} // \text{group}, \text{ks_w})`, where group is the group number + to split the input in the channel dimension. + + If the 'pad_mode' is set to be "valid", the output width will be + :math:`\left \lfloor{1 + \frac{W_{in} + 2 \times \text{padding} - \text{ks_w} - + (\text{ks_w} - 1) \times (\text{dilation} - 1) }{\text{stride}}} \right \rfloor` respectively. + + The first introduction can be found in paper `Gradient Based Learning Applied to Document Recognition + `_. + + Args: + in_channels (int): The number of input channel :math:`C_{in}`. + out_channels (int): The number of output channel :math:`C_{out}`. + kernel_size (int): The data type is int. Specifies the + width of the 1D convolution window. + stride (int): The distance of kernel moving, an int number that represents + the width of movement. Default: 1. + pad_mode (str): Specifies padding mode. The optional values are + "same", "valid", "pad". Default: "same". + + - same: Adopts the way of completion. Output width will be the same as the input. + Total number of padding will be calculated for horizontal + direction and evenly distributed to left and right if possible. Otherwise, the + last extra padding will be done from the bottom and the right side. If this mode is set, `padding` + must be 0. + + - valid: Adopts the way of discarding. The possibly largest width of output will be return + without padding. Extra pixels will be discarded. If this mode is set, `padding` + must be 0. + + - pad: Implicit paddings on both sides of the input. The number of `padding` will be padded to the input + Tensor borders. `padding` should be greater than or equal to 0. + + padding (int): Implicit paddings on both sides of the input. Default: 0. + dilation (int): The data type is int. Specifies the dilation rate + to use for dilated convolution. If set to be :math:`k > 1`, there will + be :math:`k - 1` pixels skipped for each sampling location. Its value should + be greater or equal to 1 and bounded by the height and width of the + input. Default: 1. + group (int): Split filter into groups, `in_ channels` and `out_channels` should be + divisible by the number of groups. Default: 1. + has_bias (bool): Specifies whether the layer uses a bias vector. Default: False. + weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel. + It can be a Tensor, a string, an Initializer or a numbers.Number. When a string is specified, + values from 'TruncatedNormal', 'Normal', 'Uniform', 'HeUniform' and 'XavierUniform' distributions as well + as constant 'One' and 'Zero' distributions are possible. Alias 'xavier_uniform', 'he_uniform', 'ones' + and 'zeros' are acceptable. Uppercase and lowercase are both acceptable. Refer to the values of + Initializer for more details. Default: 'normal'. + bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Possible + Initializer and string are the same as 'weight_init'. Refer to the values of + Initializer for more details. Default: 'zeros'. + + Inputs: + - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, W_{in})`. + + Outputs: + Tensor of shape :math:`(N, C_{out}, W_{out})`. + + Examples: + >>> net = nn.Conv1d(120, 240, 4, has_bias=False, weight_init='normal') + >>> input = Tensor(np.ones([1, 120, 640]), mindspore.float32) + >>> net(input).shape + (1, 240, 640) + """ + @cell_attr_register + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + pad_mode='same', + padding=0, + dilation=1, + group=1, + has_bias=False, + weight_init='normal', + bias_init='zeros'): + + Validator.check_value_type("kernel_size", kernel_size, [int], self.cls_name) + Validator.check_value_type("stride", stride, [int], self.cls_name) + Validator.check_value_type("padding", padding, [int], self.cls_name) + Validator.check_value_type("dilation", dilation, [int], self.cls_name) + Validator.check_integer('kernel_size', kernel_size, 1, Rel.GE, self.cls_name) + Validator.check_integer('stride', stride, 1, Rel.GE, self.cls_name) + Validator.check_integer('padding', padding, 0, Rel.GE, self.cls_name) + Validator.check_integer('dilation', dilation, 1, Rel.GE, self.cls_name) + kernel_size = (1, kernel_size) + stride = (1, stride) + dilation = (1, dilation) + + super(Conv1d, self).__init__( + in_channels, + out_channels, + kernel_size, + stride, + pad_mode, + padding, + dilation, + group, + has_bias, + weight_init, + bias_init) + self.padding = (0, 0, padding, padding) + self.conv2d = P.Conv2D(out_channel=self.out_channels, + kernel_size=self.kernel_size, + mode=1, + pad_mode=self.pad_mode, + pad=self.padding, + stride=self.stride, + dilation=self.dilation, + group=self.group) + self.bias_add = P.BiasAdd() + if pad_mode not in ('valid', 'same', 'pad'): + raise ValueError('Attr \'pad_mode\' of \'Conv1d\' Op passed ' + + str(pad_mode) + ', should be one of values in \'valid\', \'same\', \'pad\'.') + self.expand_dims = P.ExpandDims() + self.squeeze = P.Squeeze(2) + self.shape = P.Shape() + + def construct(self, x): + x_shape = self.shape(x) + if len(x_shape) == 3: + x = self.expand_dims(x, 2) + output = self.conv2d(x, self.weight) + if self.has_bias: + output = self.bias_add(output, self.bias) + if len(x_shape) == 3: + output = self.squeeze(output) + return output + + def extend_repr(self): + s = 'input_channels={}, output_channels={}, kernel_size={},' \ + 'stride={}, pad_mode={}, padding={}, dilation={}, ' \ + 'group={}, has_bias={},' \ + 'weight_init={}, bias_init={}'.format( + self.in_channels, + self.out_channels, + self.kernel_size, + self.stride, + self.pad_mode, + self.padding, + self.dilation, + self.group, + self.has_bias, + self.weight, + self.bias) + + if self.has_bias: + s += ', bias={}'.format(self.bias) + return s + + class Conv2dTranspose(_Conv): r""" 2D transposed convolution layer. @@ -400,6 +569,181 @@ class Conv2dTranspose(_Conv): return s +class Conv1dTranspose(_Conv): + r""" + 1D transposed convolution layer. + + Compute a 1D transposed convolution, which is also know as a deconvolution + (although it is not actual deconvolution). + + Input is typically of shape :math:`(N, C, W)`, where :math:`N` is batch size and :math:`C` is channel number. + + Args: + in_channels (int): The number of channels in the input space. + out_channels (int): The number of channels in the output space. + kernel_size (int): int, which specifies the width of the 1D convolution window. + stride (int): The distance of kernel moving, an int number that represents + the width of movement. Default: 1. + pad_mode (str): Select the mode of the pad. The optional values are + "pad", "same", "valid". Default: "same". + + - pad: Implicit paddings on both sides of the input. + + - same: Adopted the way of completion. + + - valid: Adopted the way of discarding. + padding (int): Implicit paddings on both sides of the input. Default: 0. + dilation (int): The data type is int. Specifies the dilation rate + to use for dilated convolution. If set to be :math:`k > 1`, there will + be :math:`k - 1` pixels skipped for each sampling location. Its value should + be greater or equal to 1 and bounded by the width of the + input. Default: 1. + group (int): Split filter into groups, `in_channels` and `out_channels` should be + divisible by the number of groups. This is not support for Davinci devices when group > 1. Default: 1. + has_bias (bool): Specifies whether the layer uses a bias vector. Default: False. + weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel. + It can be a Tensor, a string, an Initializer or a numbers.Number. When a string is specified, + values from 'TruncatedNormal', 'Normal', 'Uniform', 'HeUniform' and 'XavierUniform' distributions as well + as constant 'One' and 'Zero' distributions are possible. Alias 'xavier_uniform', 'he_uniform', 'ones' + and 'zeros' are acceptable. Uppercase and lowercase are both acceptable. Refer to the values of + Initializer for more details. Default: 'normal'. + bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Possible + Initializer and string are the same as 'weight_init'. Refer to the values of + Initializer for more details. Default: 'zeros'. + + Inputs: + - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, W_{in})`. + + Outputs: + Tensor of shape :math:`(N, C_{out}, W_{out})`. + + Examples: + >>> net = nn.Conv1dTranspose(3, 64, 4, has_bias=False, weight_init='normal') + >>> input = Tensor(np.ones([1, 3, 50]), mindspore.float32) + >>> net(input) + """ + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + pad_mode='same', + padding=0, + dilation=1, + group=1, + has_bias=False, + weight_init='normal', + bias_init='zeros'): + Validator.check_value_type("kernel_size", kernel_size, [int], self.cls_name) + Validator.check_value_type("stride", stride, [int], self.cls_name) + Validator.check_value_type("padding", padding, [int], self.cls_name) + Validator.check_value_type("dilation", dilation, [int], self.cls_name) + Validator.check_integer('kernel_size', kernel_size, 1, Rel.GE, self.cls_name) + Validator.check_integer('stride', stride, 1, Rel.GE, self.cls_name) + Validator.check_integer('padding', padding, 0, Rel.GE, self.cls_name) + Validator.check_integer('dilation', dilation, 1, Rel.GE, self.cls_name) + kernel_size = (1, kernel_size) + stride = (1, stride) + dilation = (1, dilation) + # out_channels and in_channels swap. + # cause Conv2DBackpropInput's out_channel refers to Conv2D's out_channel, + # then Conv1dTranspose's out_channel refers to Conv2DBackpropInput's in_channel. + super(Conv1dTranspose, self).__init__( + in_channels, + out_channels, + kernel_size, + stride, + pad_mode, + padding, + dilation, + group, + has_bias, + weight_init, + bias_init, + transposed=True) + self.padding = (0, 0, padding, padding) + self.in_channels = in_channels + self.out_channels = out_channels + self.shape = P.Shape() + if pad_mode not in ('valid', 'same', 'pad'): + raise ValueError('Attr \'pad_mode\' of \'Conv1dTranspose\' Op passed ' + + str(pad_mode) + ', should be one of values in \'valid\', \'same\', \'pad\'.') + self.is_valid = self.pad_mode == 'valid' + self.is_same = self.pad_mode == 'same' + self.is_pad = self.pad_mode == 'pad' + if check_bool(has_bias): + self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias') + + # cause Conv2DBackpropInput's out_channel refers to Conv2D's out_channel. + self.conv2d_transpose = P.Conv2DBackpropInput(out_channel=in_channels, + kernel_size=kernel_size, + mode=1, + pad_mode=pad_mode, + pad=self.padding, + stride=stride, + dilation=dilation, + group=group) + self.bias_add = P.BiasAdd() + self.expand_dims = P.ExpandDims() + self.squeeze = P.Squeeze(2) + + def set_strategy(self, strategy): + self.conv2d_transpose.set_strategy(strategy) + return self + + def _deconv_output_length(self, input_length, filter_size, stride_size, dilation_size, padding): + """Calculate the width and height of output.""" + length = 0 + filter_size = filter_size + (filter_size - 1) * (dilation_size - 1) + if self.is_valid: + if filter_size - stride_size > 0: + length = input_length * stride_size + filter_size - stride_size + else: + length = input_length * stride_size + elif self.is_same: + length = input_length * stride_size + elif self.is_pad: + length = input_length * stride_size - padding + filter_size - stride_size + + return length + + def construct(self, x): + x_shape = self.shape(x) + if len(x_shape) == 3: + x = self.expand_dims(x, 2) + + n, _, h, w = self.shape(x) + + h_out = self._deconv_output_length(h, self.kernel_size[0], self.stride[0], self.dilation[0], + self.padding[0] + self.padding[1]) + w_out = self._deconv_output_length(w, self.kernel_size[1], self.stride[1], self.dilation[1], + self.padding[2] + self.padding[3]) + if self.has_bias: + return self.bias_add(self.conv2d_transpose(x, self.weight, (n, self.out_channels, h_out, w_out)), + self.bias) + output = self.conv2d_transpose(x, self.weight, (n, self.out_channels, h_out, w_out)) + if len(x_shape) == 3: + output = self.squeeze(output) + return output + + def extend_repr(self): + s = 'input_channels={}, output_channels={}, kernel_size={},' \ + 'stride={}, pad_mode={}, padding={}, dilation={}, ' \ + 'group={}, has_bias={},' \ + 'weight_init={}, bias_init={}'.format(self.in_channels, + self.out_channels, + self.kernel_size, + self.stride, + self.pad_mode, + self.padding, + self.dilation, + self.group, + self.has_bias, + self.weight, + self.bias) + return s + + class DepthwiseConv2d(Cell): r""" 2D depthwise convolution layer. diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 35860b314e..560f6a1988 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -780,7 +780,9 @@ class Conv2D(PrimitiveWithInfer): mode (int): 0 Math convolutiuon, 1 cross-correlation convolution , 2 deconvolution, 3 depthwise convolution. Default: 1. pad_mode (str): "valid", "same", "pad" the mode to fill padding. Default: "valid". - pad (int): The pad value to fill. Default: 0. + pad (Union(int, tuple[int])): The pad value to fill. Default: 0. If `pad` is one integer, the padding of + top, bottom, left and right is same, equal to pad. If `pad` is tuple with four integer, the padding + of top, bottom, left and right equal to pad[0], pad[1], pad[2], pad[3] with corresponding. stride (Union(int, tuple[int])): The stride to apply conv filter. Default: 1. dilation (Union(int, tuple[int])): Specify the space to use between kernel elements. Default: 1. group (int): Split input into groups. Default: 1. @@ -820,11 +822,19 @@ class Conv2D(PrimitiveWithInfer): self.add_prim_attr('stride', self.stride) self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True) self.add_prim_attr('dilation', self.dilation) - validator.check_value_type('pad', pad, (int,), self.name) + validator.check_value_type('pad', pad, (int, tuple), self.name) + if isinstance(pad, int): + pad = (pad,) * 4 + else: + validator.check_integer('pad size', len(pad), 4, Rel.EQ, self.name) + self.padding = pad self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad'], self.name) - self.pad = validator.check_pad_value_by_mode(pad_mode, pad, self.name) + + if pad_mode != 'pad' and pad != (0, 0, 0, 0): + raise ValueError(f"For '{self.name}', padding must be zero when pad_mode is '{pad_mode}'.") if self.pad_mode == 'pad': - validator.check_integer('pad', self.pad, 0, Rel.GE, self.name) + for item in pad: + validator.check_integer('pad item', item, 0, Rel.GE, self.name) self.mode = validator.check_integer('mode', mode, 1, Rel.EQ, self.name) self.add_prim_attr('data_format', "NCHW") @@ -862,11 +872,11 @@ class Conv2D(PrimitiveWithInfer): pad_left = math.floor(pad_needed_w / 2) pad_right = pad_needed_w - pad_left elif self.pad_mode == 'pad': - pad_top, pad_bottom, pad_left, pad_right = self.pad, self.pad, self.pad, self.pad + pad_top, pad_bottom, pad_left, pad_right = self.padding - h_out = 1 + (x_shape[2] + 2 * self.pad - kernel_size_h - (kernel_size_h - 1) * (dilation_h - 1)) \ + h_out = 1 + (x_shape[2] + pad_top + pad_bottom - kernel_size_h - (kernel_size_h - 1) * (dilation_h - 1)) \ / stride_h - w_out = 1 + (x_shape[3] + 2 * self.pad - kernel_size_w - (kernel_size_w - 1) * (dilation_w - 1)) \ + w_out = 1 + (x_shape[3] + pad_left + pad_right - kernel_size_w - (kernel_size_w - 1) * (dilation_w - 1)) \ / stride_w h_out = math.floor(h_out) w_out = math.floor(w_out) @@ -1277,7 +1287,9 @@ class Conv2DBackpropInput(PrimitiveWithInfer): out_channel (int): The dimensionality of the output space. kernel_size (Union[int, tuple[int]]): The size of the convolution window. pad_mode (str): "valid", "same", "pad" the mode to fill padding. Default: "valid". - pad (int): The pad value to fill. Default: 0. + pad (Union[int, tuple[int]]): The pad value to fill. Default: 0. If `pad` is one integer, the padding of + top, bottom, left and right is same, equal to pad. If `pad` is tuple with four integer, the padding + of top, bottom, left and right equal to pad[0], pad[1], pad[2], pad[3] with corresponding. mode (int): 0 Math convolutiuon, 1 cross-correlation convolution , 2 deconvolution, 3 depthwise convolution. Default: 1. stride (Union[int. tuple[int]]): The stride to apply conv filter. Default: 1. @@ -1314,9 +1326,21 @@ class Conv2DBackpropInput(PrimitiveWithInfer): self.add_prim_attr('stride', self.stride) self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True) self.add_prim_attr('dilation', self.dilation) - validator.check_value_type('pad', pad, (int,), self.name) + + validator.check_value_type('pad', pad, (int, tuple), self.name) + if isinstance(pad, int): + pad = (pad,) * 4 + self.pad = pad + else: + validator.check_integer('pad size', len(pad), 4, Rel.EQ, self.name) + self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad'], self.name) - self.pad = validator.check_pad_value_by_mode(pad_mode, pad, self.name) + if pad_mode != 'pad' and pad != (0, 0, 0, 0): + raise ValueError(f"For '{self.name}', padding must be zero when pad_mode is '{pad_mode}'.") + if self.pad_mode == 'pad': + for item in pad: + validator.check_integer('pad item', item, 0, Rel.GE, self.name) + pad_mode = pad_mode.upper() self.add_prim_attr('pad_mode', pad_mode) self.mode = validator.check_integer('mode', mode, 1, Rel.EQ, self.name) @@ -1358,7 +1382,7 @@ class Conv2DBackpropInput(PrimitiveWithInfer): pad_right = pad_needed_w - pad_left pad_list = (pad_top, pad_bottom, pad_left, pad_right) elif self.pad_mode == 'PAD': - pad_list = (self.pad,) * 4 + pad_list = self.pad self.add_prim_attr('pad_list', pad_list) out = { 'value': None, diff --git a/tests/ut/cpp/python_input/gtest_input/pynative/ops_test.py b/tests/ut/cpp/python_input/gtest_input/pynative/ops_test.py index be31b0f709..d27d6c5eca 100644 --- a/tests/ut/cpp/python_input/gtest_input/pynative/ops_test.py +++ b/tests/ut/cpp/python_input/gtest_input/pynative/ops_test.py @@ -22,11 +22,22 @@ from mindspore.ops.vm_impl_registry import vm_impl_registry as vm_impl_getters def im2col(img, filter_h, filter_w, stride=1, pad=0, dilation=1): """Rearranges an image to row vector""" + if isinstance(pad, int): + pad_top = pad + pad_bottom = pad + pad_left = pad + pad_right = pad + elif isinstance(pad, tuple) and len(pad) == 4: + pad_top, pad_bottom, pad_left, pad_right = pad + else: + raise ValueError(f"The \'pad\' should be an int number or " + f"a tuple of two or four int numbers, but got {pad}") + batch_num, channel, height, width = img.shape - out_h = (height + 2 * pad - filter_h - (filter_h - 1) * (dilation[2] - 1)) // stride[2] + 1 - out_w = (width + 2 * pad - filter_w - (filter_w - 1) * (dilation[3] - 1)) // stride[3] + 1 + out_h = (height + pad_top + pad_bottom - filter_h - (filter_h - 1) * (dilation[2] - 1)) // stride[2] + 1 + out_w = (width + pad_left + pad_right - filter_w - (filter_w - 1) * (dilation[3] - 1)) // stride[3] + 1 - img = np.pad(img, [(0, 0), (0, 0), (pad, pad), (pad, pad)], 'constant') + img = np.pad(img, [(0, 0), (0, 0), (pad_top, pad_bottom), (pad_left, pad_right)], 'constant') col = np.zeros((batch_num, channel, filter_h, filter_w, out_h, out_w)).astype(img.dtype) for y in range(filter_h): @@ -43,10 +54,21 @@ def im2col(img, filter_h, filter_w, stride=1, pad=0, dilation=1): def conv2d(x, weight, bias=None, stride=1, pad=0, dilation=1, groups=1, padding_mode='zeros'): """Convolution 2D""" + if isinstance(pad, int): + pad_top = pad + pad_bottom = pad + pad_left = pad + pad_right = pad + elif isinstance(pad, tuple) and len(pad) == 4: + pad_top, pad_bottom, pad_left, pad_right = pad + else: + raise ValueError(f"The \'pad\' should be an int number or " + f"a tuple of two or four int numbers, but got {pad}") + batch_num, _, x_h, x_w = x.shape filter_num, _, filter_h, filter_w = weight.shape - out_h = 1 + int((x_h + 2 * pad - filter_h - (filter_h - 1) * (dilation[2] - 1)) / stride[2]) - out_w = 1 + int((x_w + 2 * pad - filter_w - (filter_w - 1) * (dilation[3] - 1)) / stride[3]) + out_h = 1 + int((x_h + pad_top + pad_bottom - filter_h - (filter_h - 1) * (dilation[2] - 1)) / stride[2]) + out_w = 1 + int((x_w + pad_left + pad_right - filter_w - (filter_w - 1) * (dilation[3] - 1)) / stride[3]) col = im2col(x, filter_h, filter_w, stride, pad, dilation) col_w = np.reshape(weight, (filter_num, -1)).T out = np.dot(col, col_w) diff --git a/tests/vm_impl/vm_me.py b/tests/vm_impl/vm_me.py index d9973787ba..89cc1569a9 100644 --- a/tests/vm_impl/vm_me.py +++ b/tests/vm_impl/vm_me.py @@ -169,16 +169,32 @@ def col2im(col, input_shape, filter_h, filter_w, stride=1, pad=0): raise ValueError(f"The \'stride\' should be an int number or " f"a tuple of two or four int numbers, but got {stride}") + if isinstance(pad, int): + pad_top = pad + pad_bottom = pad + pad_left = pad + pad_right = pad + elif isinstance(pad, tuple) and len(pad) == 2: + pad_top = pad[0] + pad_bottom = pad[0] + pad_left = pad[1] + pad_right = pad[1] + elif isinstance(pad, tuple) and len(pad) == 4: + pad_top, pad_bottom, pad_left, pad_right = pad + else: + raise ValueError(f"The \'pad\' should be an int number or " + f"a tuple of two or four int numbers, but got {pad}") + batch_num, channel, height, width = input_shape - out_h = (height + 2 * pad - filter_h) // stride_h + 1 - out_w = (width + 2 * pad - filter_w) // stride_w + 1 + out_h = (height + pad_top + pad_bottom - filter_h) // stride_h + 1 + out_w = (width + pad_left + pad_right - filter_w) // stride_w + 1 col = col.reshape(batch_num, out_h, out_w, channel, filter_h, filter_w) \ .transpose(0, 3, 4, 5, 1, 2) img = np.zeros((batch_num, channel, - height + 2 * pad + stride_h - 1, - width + 2 * pad + stride_w - 1)) \ + height + pad_top + pad_bottom + stride_h - 1, + width + pad_left + pad_right + stride_w - 1)) \ .astype(col.dtype) for y in range(filter_h): y_max = y + stride_h * out_h @@ -186,7 +202,7 @@ def col2im(col, input_shape, filter_h, filter_w, stride=1, pad=0): x_max = x + stride_h * out_w img[:, :, y:y_max:stride_h, x:x_max:stride_h] += col[:, :, y, x, :, :] - return img[:, :, pad:height + pad, pad:width + pad] + return img[:, :, pad_top:height + pad_bottom, pad_left:width + pad_right] def convolve(x, w, b=None, pad_mode="valid"): @@ -243,10 +259,21 @@ def conv2d(x, weight, bias=None, stride=1, pad=0, dilation_h = dilation[0] dilation_w = dilation[1] + if isinstance(pad, int): + pad_top = pad + pad_bottom = pad + pad_left = pad + pad_right = pad + elif isinstance(pad, tuple) and len(pad) == 4: + pad_top, pad_bottom, pad_left, pad_right = pad + else: + raise ValueError(f"The \'pad\' should be an int number or " + f"a tuple of two or four int numbers, but got {pad}") + batch_num, _, x_h, x_w = x.shape filter_num, _, filter_h, filter_w = weight.shape - out_h = 1 + int((x_h + 2 * pad - filter_h - (filter_h - 1) * (dilation_h - 1)) / stride_h) - out_w = 1 + int((x_w + 2 * pad - filter_w - (filter_w - 1) * (dilation_w - 1)) / stride_w) + out_h = 1 + int((x_h + pad_top + pad_bottom - filter_h - (filter_h - 1) * (dilation_h - 1)) / stride_h) + out_w = 1 + int((x_w + pad_left + pad_right - filter_w - (filter_w - 1) * (dilation_w - 1)) / stride_w) col = im2col(x, filter_h, filter_w, stride, pad, dilation) col_w = np.reshape(weight, (filter_num, -1)).T out = np.dot(col, col_w) @@ -348,11 +375,22 @@ def im2col(img, filter_h, filter_w, stride=1, pad=0, dilation=1): raise ValueError(f"The \'dilation\' should be an int number or " f"a tuple of two or four int numbers, but got {dilation}") + if isinstance(pad, int): + pad_top = pad + pad_bottom = pad + pad_left = pad + pad_right = pad + elif isinstance(pad, tuple) and len(pad) == 4: + pad_top, pad_bottom, pad_left, pad_right = pad + else: + raise ValueError(f"The \'pad\' should be an int number or " + f"a tuple of two or four int numbers, but got {pad}") + batch_num, channel, height, width = img.shape - out_h = (height + 2 * pad - filter_h - (filter_h - 1) * (dilation_h - 1)) // stride_h + 1 - out_w = (width + 2 * pad - filter_w - (filter_w - 1) * (dilation_w - 1)) // stride_w + 1 + out_h = (height + pad_top + pad_bottom - filter_h - (filter_h - 1) * (dilation_h - 1)) // stride_h + 1 + out_w = (width + pad_left + pad_right - filter_w - (filter_w - 1) * (dilation_w - 1)) // stride_w + 1 - img = np.pad(img, [(0, 0), (0, 0), (pad, pad), (pad, pad)], 'constant') + img = np.pad(img, [(0, 0), (0, 0), (pad_top, pad_bottom), (pad_left, pad_right)], 'constant') col = np.zeros((batch_num, channel, filter_h, filter_w, out_h, out_w)).astype(img.dtype) for y in range(filter_h):