add conv3d_transpose.

pull/9578/head
jiangzhenguang 4 years ago
parent 80fe11ac7b
commit 5060565920

@@ -65,12 +65,12 @@ def get_bprop_conv2d(self):
def get_bprop_conv3d(self):
"""Grad definition for `Conv3D` operation."""
input_grad = nps.Conv3DBackpropInput(
self.out_channel, self.kernel_size, self.pad_mode, self.pad, mode=self.mode,
dilation=self.dilation, stride=self.stride, group=self.group, data_format=self.format
self.out_channel, self.kernel_size, self.mode, pad_mode=self.pad_mode,
pad=self.pad, stride=self.stride, dilation=self.dilation, group=self.group, data_format=self.data_format
)
filter_grad = G.Conv3DBackpropFilter(
self.out_channel, self.kernel_size, self.pad_mode, self.pad, mode=self.mode,
dilation=self.dilation, stride=self.stride, group=self.group, data_format=self.format
self.out_channel, self.kernel_size, self.mode, pad_mode=self.pad_mode,
pad=self.pad, stride=self.stride, dilation=self.dilation, group=self.group, data_format=self.data_format
)
get_shape = P.Shape()
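This hunk switches both bprop call sites to keyword arguments because the constructors of Conv3DBackpropInput and Conv3DBackpropFilter now take `mode` ahead of `pad_mode` (see the signature hunks further down). A minimal sketch of why positional calls break silently under such a reorder, using toy stand-in signatures (hypothetical, not the real primitives):

def old_init(out_channel, kernel_size, pad_mode="valid", pad=0, mode=1):
    return dict(pad_mode=pad_mode, pad=pad, mode=mode)

def new_init(out_channel, kernel_size, mode=1, pad_mode="valid", pad=0):
    return dict(pad_mode=pad_mode, pad=pad, mode=mode)

# a positional call written against the old order now binds "pad" to mode
assert old_init(8, 3, "pad", 1) == dict(pad_mode="pad", pad=1, mode=1)
assert new_init(8, 3, "pad", 1) == dict(pad_mode=1, pad=0, mode="pad")

# keyword arguments, as the updated bprop uses, are immune to the reorder
assert new_init(8, 3, pad_mode="pad", pad=1) == dict(pad_mode="pad", pad=1, mode=1)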
@@ -82,6 +82,27 @@ def get_bprop_conv3d(self):
return bprop
@bprop_getters.register(nps.Conv3DTranspose)
def get_bprop_conv3d_transpose(self):
"""Grad definition for `Conv3DTranspose` operation."""
filter_grad = G.Conv3DBackpropFilter(
out_channel=self.out_channel, kernel_size=self.kernel_size, mode=self.mode, pad_mode=self.pad_mode,
pad=self.pad, stride=self.stride, dilation=self.dilation, group=self.group, data_format=self.data_format
)
input_grad = nps.Conv3D(
out_channel=self.out_channel, kernel_size=self.kernel_size, mode=self.mode, pad_mode=self.pad_mode,
pad=self.pad, stride=self.stride, dilation=self.dilation, group=self.group, data_format=self.data_format
)
input_size = self.input_size
def bprop(x, w, out, dout):
dx = input_grad(dout, w)
dw = filter_grad(dout, x, F.shape(w))
return dx, dw, zeros_like(input_size)
return bprop
@bprop_getters.register(inner.ExtractImagePatches)
def get_bprop_extract_image_patches(self):
"""Grad definition for `ExtractImagePatches` operation."""

@@ -31,9 +31,9 @@ conv3d_transpose_op_info = TBERegOp("Conv3DTranspose") \
.attr("data_format", "optional", "str", "all") \
.attr("output_padding", "optional", "listInt", "all") \
.input(0, "x", False, "required", "all") \
.input(0, "filter", False, "required", "all") \
.input(0, "bias", False, "optional", "all") \
.input(1, "offset_w", False, "optional", "all") \
.input(1, "filter", False, "required", "all") \
.input(2, "bias", False, "optional", "all") \
.input(3, "offset_w", False, "optional", "all") \
.output(0, "y", True, "required", "all") \
.dtype_format(DataType.F16_NDC1HWC0, DataType.F16_FRACTAL_Z_3D, DataType.F16_Default, DataType.I8_Default,
DataType.F16_NDC1HWC0) \
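The registration fix above gives every operator input its own consecutive index (previously x, filter and bias all claimed slot 0). A tiny sanity check one could run over such a table (illustrative only, not part of the TBE API):

# the corrected input table from the hunk above; indices must be unique and
# consecutive so x/filter/bias/offset_w bind to the intended slots
inputs = [(0, "x"), (1, "filter"), (2, "bias"), (3, "offset_w")]
assert [idx for idx, _ in inputs] == list(range(len(inputs)))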

@@ -335,9 +335,9 @@ class Conv3DBackpropFilter(PrimitiveWithInfer):
def __init__(self,
out_channel,
kernel_size,
mode=1,
pad_mode="valid",
pad=0,
mode=1,
stride=(1, 1, 1, 1, 1),
dilation=(1, 1, 1, 1, 1),
group=1,
@@ -366,6 +366,7 @@ class Conv3DBackpropFilter(PrimitiveWithInfer):
self.add_prim_attr('pad_mode', self.pad_mode)
self.mode = validator.check_equal_int(mode, 1, 'mode', self.name)
self.add_prim_attr('mode', self.mode)
self.group = validator.check_positive_int(group, 'group', self.name)
self.add_prim_attr('groups', self.group)
self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name)
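This hunk registers `group` with the kernel under the `groups` attribute. A minimal sketch (an assumed helper, not part of the patch) of the channel bookkeeping that grouped 3D convolution implies:

def per_group_channels(in_channel, out_channel, group):
    # grouped convolution splits both channel dimensions evenly across groups
    if in_channel % group or out_channel % group:
        raise ValueError("in_channel and out_channel must be divisible by group")
    return in_channel // group, out_channel // group

assert per_group_channels(16, 8, 4) == (4, 2)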

@@ -6768,8 +6768,8 @@ class Conv3D(PrimitiveWithInfer):
Inputs:
- **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`.
- **weight** (Tensor) - Set size of kernel is :math:`(D_1, K_2, K_3)`, then the shape is
:math:`(C_{out}, C_{in}, D_{in}, K_1, K_2)`.
- **weight** (Tensor) - If the size of the kernel is :math:`(D_{in}, K_h, K_w)`, then the shape is
:math:`(C_{out}, C_{in}, D_{in}, K_h, K_w)`.
Outputs:
Tensor, the value that applied 3D convolution. The shape is :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})`.
@@ -6819,6 +6819,7 @@ class Conv3D(PrimitiveWithInfer):
validator.check_non_negative_int(item, 'pad item', self.name)
self.mode = validator.check_equal_int(mode, 1, 'mode', self.name)
self.add_prim_attr('mode', self.mode)
self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name)
self.add_prim_attr('data_format', self.format)
self.add_prim_attr('io_format', "NCDHW")
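For pad_mode 'valid', Conv3D's spatial output size is (in - dilation*(kernel - 1) - 1) // stride + 1 per dimension, which is exactly what the Conv3DTranspose example further down inverts. A quick standalone check (my own helper, not from the patch):

def conv3d_valid_out(in_dhw, kernel, stride=(1, 1, 1), dilation=(1, 1, 1)):
    # 'valid' output size per spatial dim
    return tuple((i - d * (k - 1) - 1) // s + 1
                 for i, k, s, d in zip(in_dhw, kernel, stride, dilation))

# the transpose example maps (10, 32, 32) -> (13, 37, 33) with kernel (4, 6, 2);
# a 'valid' Conv3D maps it straight back
assert conv3d_valid_out((13, 37, 33), (4, 6, 2)) == (10, 32, 32)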
@@ -6916,8 +6917,8 @@ class Conv3DBackpropInput(PrimitiveWithInfer):
data_format (str): The optional value for data format. Currently only 'NCDHW' is supported.
Inputs:
- **weight** (Tensor) - Set size of kernel is :math:`(K_1, K_2, K_3)`, then the shape is
:math:`(C_{out}, C_{in}, D_{in}, K_1, K_2)`.
- **weight** (Tensor) - If the size of the kernel is :math:`(D_{in}, K_h, K_w)`, then the shape is
:math:`(C_{out}, C_{in}, D_{in}, K_h, K_w)`.
- **dout** (Tensor) - the gradients w.r.t the output of the convolution. The shape conforms to the default
data_format :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})`.
- **input_size** (Tensor) - A tuple describes the shape of the input which conforms to the format
@@ -6943,9 +6944,9 @@ class Conv3DBackpropInput(PrimitiveWithInfer):
def __init__(self,
out_channel,
kernel_size,
mode=1,
pad_mode="valid",
pad=0,
mode=1,
stride=1,
dilation=1,
group=1,
@@ -6973,6 +6974,7 @@ class Conv3DBackpropInput(PrimitiveWithInfer):
self.add_prim_attr('pad_mode', self.pad_mode)
self.mode = validator.check_equal_int(mode, 1, 'mode', self.name)
self.add_prim_attr('mode', self.mode)
self.group = validator.check_positive_int(group, 'group', self.name)
self.add_prim_attr('groups', self.group)
self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name)
@@ -7026,3 +7028,169 @@ class Conv3DBackpropInput(PrimitiveWithInfer):
'dtype': doutput['dtype'],
}
return out
class Conv3DTranspose(PrimitiveWithInfer):
"""
Computes a 3D transposed convolution, i.e. the gradient of Conv3D with respect to its input.
Args:
input_size (tuple[int]): The shape of the output, a tuple of five integers. If `input_size` is set to
(0, 0, 0, 0, 0), the `output_padding` setting takes effect and `pad_mode` cannot be set to 'same'.
Otherwise, the output shape is exactly `input_size` and the `output_padding` setting is ignored.
Default: (0, 0, 0, 0, 0).
out_channel (int): The dimension of the output.
kernel_size (Union[int, tuple[int]]): The kernel size of the 3D convolution.
mode (int): Modes for different convolutions. Not currently used.
pad_mode (str): Modes to fill padding. It could be "valid", "same", or "pad". Default: "valid".
pad (Union(int, tuple[int])): The pad value to be filled. Default: 0. If `pad` is an integer, the paddings of
head, tail, top, bottom, left and right are all equal to `pad`. If `pad` is a tuple of six integers, the
paddings of head, tail, top, bottom, left and right equal pad[0], pad[1], pad[2], pad[3], pad[4]
and pad[5] correspondingly.
stride (Union(int, tuple[int])): The stride to be applied to the convolution filter. Default: 1.
dilation (Union(int, tuple[int])): Specifies the space to use between kernel elements. Default: 1.
group (int): Splits input into groups. Default: 1.
output_padding (Union(int, tuple[int])): Add extra size to each dimension of the output. Default: 0.
data_format (str): The optional value for data format. Currently only 'NCDHW' is supported.
Inputs:
- **dout** (Tensor) - the gradients w.r.t the output of the convolution. The shape conforms to the default
data_format :math:`(N, C_{out}, D_{out}, H_{out}, W_{out})`.
- **weight** (Tensor) - If the size of the kernel is :math:`(D_{in}, K_h, K_w)`, then the shape is
:math:`(C_{out}, C_{in}, D_{in}, K_h, K_w)`.
- **input_size** (Tensor) - A tuple describes the shape of the input which conforms to the format
:math:`(N, C_{in}, D_{in}, H_{in}, W_{in})`.
Outputs:
Tensor, the gradients w.r.t the input of convolution 3D. It has the same shape as the input.
Supported Platforms:
``Ascend``
Examples:
>>> input_x = Tensor(np.ones([32, 16, 10, 32, 32]), mindspore.float32)
>>> weight = Tensor(np.ones([16, 3, 4, 6, 2]), mindspore.float32)
>>> conv3d_transpose = P.Conv3DTranspose(out_channel=3, kernel_size=(4, 6, 2))
>>> output = conv3d_transpose(input_x, weight)
>>> print(output.shape)
(32, 3, 13, 37, 33)
"""
@prim_attr_register
def __init__(self,
out_channel,
kernel_size,
input_size=(0, 0, 0, 0, 0),
mode=1,
pad_mode="valid",
pad=0,
stride=1,
dilation=1,
group=1,
output_padding=0,
data_format="NCDHW"):
"""Initialize Conv3DTranspose"""
self.init_prim_io_names(inputs=['x', 'filter', 'input_size'], outputs=['output'])
self.input_size = validator.check_value_type('input_size', input_size, [tuple], self.name)
validator.check_equal_int(len(self.input_size), 5, 'input_size', self.name)
for i, dim_len in enumerate(self.input_size):
validator.check_value_type("input_size[%d]" % i, dim_len, [int], self.name)
self.add_prim_attr('input_size', self.input_size)
self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name)
self.kernel_size = _check_3d_int_or_tuple('kernel_size', kernel_size, self.name)
self.stride = _check_3d_int_or_tuple('stride', stride, self.name, allow_five=True, ret_five=True)
self.add_prim_attr('strides', self.stride)
self.dilation = _check_3d_int_or_tuple('dilation', dilation, self.name, allow_five=True, ret_five=True)
self.add_prim_attr('dilations', self.dilation)
validator.check_value_type('pad', pad, (int, tuple), self.name)
if isinstance(pad, int):
pad = (pad,) * 6
validator.check_equal_int(len(pad), 6, 'pad size', self.name)
self.pad_list = pad
self.pad_mode = validator.check_string(pad_mode.lower(), ['valid', 'same', 'pad'], 'pad_mode', self.name)
if self.pad_mode != 'pad' and self.pad_list != (0, 0, 0, 0, 0, 0):
raise ValueError(f"For '{self.name}', when pad is not 0, pad_mode should be set to 'pad'.")
if self.pad_mode == 'pad':
for item in pad:
validator.check_non_negative_int(item, 'pad item', self.name)
self.add_prim_attr('pad_mode', self.pad_mode)
self.mode = validator.check_equal_int(mode, 1, 'mode', self.name)
self.add_prim_attr('mode', self.mode)
self.group = validator.check_positive_int(group, 'group', self.name)
self.add_prim_attr('groups', self.group)
self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.name)
self.add_prim_attr('data_format', self.format)
self.add_prim_attr('io_format', "NCDHW")
self.output_padding = _check_3d_int_or_tuple('output_padding', output_padding, self.name,
allow_five=True, ret_five=True, greater_zero=False)
self.add_prim_attr('output_padding', self.output_padding)
def __infer__(self, x, w, b=None):
args = {'x': x['dtype'], 'w': w['dtype']}
if b is not None:
args = {'x': x['dtype'], 'w': w['dtype'], 'b': b['dtype']}
valid_dtypes = [mstype.float16, mstype.float32]
validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
output_shape = self.input_size
# infer shape
x_shape = x['shape']
kernel_d = self.kernel_size[0]
kernel_h = self.kernel_size[1]
kernel_w = self.kernel_size[2]
stride_d = self.stride[2]
stride_h = self.stride[3]
stride_w = self.stride[4]
dilation_d = self.dilation[2]
dilation_h = self.dilation[3]
dilation_w = self.dilation[4]
if self.input_size != (0, 0, 0, 0, 0):
# pad_mode defaults to 'valid'; when it is neither 'valid' nor 'same', the user-supplied pads apply.
if self.pad_mode == "valid":
self.pad_list = (0, 0, 0, 0, 0, 0)
if self.pad_mode == "same":
pad_needed_d = max(0, (x_shape[2] - 1) * stride_d + dilation_d *
(kernel_d - 1) + 1 - self.input_size[2])
pad_head = math.floor(pad_needed_d / 2)
pad_tail = pad_needed_d - pad_head
pad_needed_h = max(0, (x_shape[3] - 1) * stride_h + dilation_h *
(kernel_h - 1) + 1 - self.input_size[3])
pad_top = math.floor(pad_needed_h / 2)
pad_bottom = pad_needed_h - pad_top
pad_needed_w = max(0, (x_shape[4] - 1) * stride_w + dilation_w *
(kernel_w - 1) + 1 - self.input_size[4])
pad_left = math.floor(pad_needed_w / 2)
pad_right = pad_needed_w - pad_left
self.pad_list = (pad_head, pad_tail, pad_top, pad_bottom, pad_left, pad_right)
self.add_prim_attr('pads', self.pad_list)
else:
self.add_prim_attr('pads', self.pad_list)
if self.pad_mode == 'same':
raise ValueError("When input_size is (0, 0, 0, 0, 0), the pad_mode cannot be 'same'!")
pad_head, pad_tail, pad_top, pad_bottom, pad_left, pad_right = self.pad_list
w_shape = w['shape']
self.output_padding = self.output_padding if self.format == "NCDHW" else \
(self.output_padding[0], self.output_padding[4], self.output_padding[1],
self.output_padding[2], self.output_padding[3])
d_out = (x_shape[2] - 1) * stride_d - (pad_head + pad_tail) + dilation_d * \
(kernel_d - 1) + self.output_padding[2] + 1
h_out = (x_shape[3] - 1) * stride_h - (pad_top + pad_bottom) + dilation_h * \
(kernel_h - 1) + self.output_padding[3] + 1
w_out = (x_shape[4] - 1) * stride_w - (pad_left + pad_right) + dilation_w * \
(kernel_w - 1) + self.output_padding[4] + 1
output_shape = (x_shape[0], w_shape[1], d_out, h_out, w_out)
self.add_prim_attr('input_size', output_shape)
validator.check("filter's channel", w['shape'][1], "input_size's channel", output_shape[1], Rel.EQ, self.name)
validator.check("filter's batch", w['shape'][0], "input x's channel", x['shape'][1], Rel.EQ, self.name)
validator.check("input_size's batch", output_shape[0], "x's batch", x['shape'][0], Rel.EQ, self.name)
out = {
'value': None,
'shape': output_shape,
'dtype': x['dtype'],
}
return out
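The __infer__ arithmetic above, restated as a standalone helper using the standard transposed-convolution size formula (my own sketch; assumes NCDHW and that the pads come directly from pad_list, i.e. input_size left at zeros), and checked against the docstring example:

def conv3d_transpose_out(in_dhw, kernel, stride=(1, 1, 1), dilation=(1, 1, 1),
                         pads=(0, 0, 0, 0, 0, 0), output_padding=(0, 0, 0)):
    # per spatial dim: (in - 1)*stride - (pad_before + pad_after)
    #                  + dilation*(kernel - 1) + output_padding + 1
    return tuple((i - 1) * s - (pads[2 * n] + pads[2 * n + 1]) + d * (k - 1) + op + 1
                 for n, (i, k, s, d, op)
                 in enumerate(zip(in_dhw, kernel, stride, dilation, output_padding)))

# docstring example: x (32, 16, 10, 32, 32), weight (16, 3, 4, 6, 2), kernel (4, 6, 2)
assert conv3d_transpose_out((10, 32, 32), (4, 6, 2)) == (13, 37, 33)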

@@ -444,8 +444,8 @@ class Conv3DBackpropInput(nn.Cell):
def __init__(self, input_shape, out_channel, kernel_size, mode, pad_mode, pad, stride, dilation, group,
data_format):
super(Conv3DBackpropInput, self).__init__()
self.conv = nps.Conv3DBackpropInput(out_channel, kernel_size, pad_mode=pad_mode,
pad=pad, mode=mode, stride=stride, dilation=dilation,
self.conv = nps.Conv3DBackpropInput(out_channel=out_channel, kernel_size=kernel_size, mode=mode,
pad_mode=pad_mode, pad=pad, stride=stride, dilation=dilation,
group=group, data_format=data_format)
self.x_size = input_shape
@@ -459,8 +459,8 @@ class Conv3DBackpropFilter(nn.Cell):
def __init__(self, w_shape, out_channel, kernel_size, mode, pad_mode, pad, stride, dilation, group, data_format):
super(Conv3DBackpropFilter, self).__init__()
self.conv = G.Conv3DBackpropFilter(out_channel, kernel_size, pad_mode=pad_mode,
pad=pad, mode=mode, stride=stride, dilation=dilation,
self.conv = G.Conv3DBackpropFilter(out_channel=out_channel, kernel_size=kernel_size, mode=mode,
pad_mode=pad_mode, pad=pad, stride=stride, dilation=dilation,
group=group, data_format=data_format)
self.w_size = w_shape
@@ -469,6 +469,20 @@ class Conv3DBackpropFilter(nn.Cell):
return ms_out
class Conv3DTranspose(nn.Cell):
"""Conv3DTranspose net definition"""
def __init__(self, out_channel, kernel_size, input_size, mode, pad_mode, pad, stride, dilation, group, data_format):
super(Conv3DTranspose, self).__init__()
self.conv = nps.Conv3DTranspose(out_channel=out_channel, kernel_size=kernel_size, input_size=input_size,
mode=mode, pad_mode=pad_mode, pad=pad, stride=stride, dilation=dilation,
group=group, data_format=data_format)
def construct(self, x, w):
ms_out = self.conv(x, w)
return ms_out
class ApplyFtrlNet(nn.Cell):
def __init__(self):
super(ApplyFtrlNet, self).__init__()
@@ -1244,6 +1258,12 @@ test_case_math_ops = [
'desc_inputs': [Tensor(np.random.random((16, 32, 13, 37, 33)).astype(np.float16)),
Tensor(np.random.random((16, 32, 10, 32, 32)).astype(np.float16))],
'skip': ['backward']}),
('Conv3DTranspose', {
'block': Conv3DTranspose(out_channel=3, kernel_size=(4, 6, 2), input_size=(0, 0, 0, 0, 0), mode=1,
pad_mode='valid', pad=0, stride=1, dilation=1, group=1, data_format="NCDHW"),
'desc_inputs': [Tensor(np.random.random((32, 3, 10, 32, 32)).astype(np.float16)),
Tensor(np.random.random((3, 3, 4, 6, 2)).astype(np.float16))],
'skip': ['backward']}),
('CountNonZero', {
'block': CountNonZero(axis=(), keep_dims=False, dtype=mstype.int32),
'desc_inputs': [Tensor(np.random.rand(3, 16, 5, 4).astype(np.float32))],
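The new Conv3DTranspose test case's shapes follow the same arithmetic: with stride, dilation, pad and output_padding all at their defaults, each spatial dim grows by kernel - 1, and the output channel count is w_shape[1]. A self-contained trace (plain Python, my own sketch):

# shapes from the test case above
x_shape, w_shape = (32, 3, 10, 32, 32), (3, 3, 4, 6, 2)
out_dhw = tuple(i + k - 1 for i, k in zip(x_shape[2:], w_shape[2:]))
assert (x_shape[0], w_shape[1]) + out_dhw == (32, 3, 13, 37, 33)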
