fix 3d norm

pull/12906/head
jiangzhenguang 4 years ago
parent 7ba21f8d8c
commit ef1b98bf18

@ -98,9 +98,12 @@ def _check_3d_int_or_tuple(arg_name, arg_value, prim_name, allow_five=False, ret
Checks whether an argument is a positive int or tuple with 3 or 5(when allow_five is True) positive int elements.
"""
def _raise_message(third_one=False):
if third_one:
raise ValueError(f"For '{prim_name}' the depth of attr '{arg_name}' should be 1, but got {arg_value}")
def _raise_message(third_one_flag=False, three_input_flag=False):
if third_one_flag:
raise ValueError(f"For '{prim_name}' the depth of attr '{arg_name}' should be 1, but got {arg_value[-3]}")
if three_input_flag:
raise ValueError(f"For '{prim_name}' attr '{arg_name}' should be an positive int number or a tuple of "
f"three positive int numbers, but got {arg_value}")
raise ValueError(f"For '{prim_name}' attr '{arg_name}' should be an positive int number or a tuple of three "
f"{'or five ' if allow_five else ''}positive int numbers, but got {arg_value}")
@ -121,7 +124,8 @@ def _check_3d_int_or_tuple(arg_name, arg_value, prim_name, allow_five=False, ret
Validator.check_value_type(arg_name, arg_value, (int, tuple), prim_name)
if three_input and isinstance(arg_value, tuple):
Validator.check_equal_int(len(arg_value), 3, arg_name, prim_name)
if len(arg_value) != 3:
_raise_message(three_input_flag=three_input)
ret_value = _get_return_value()
for item in ret_value:
if isinstance(item, int) and not isinstance(item, bool):
@ -133,7 +137,7 @@ def _check_3d_int_or_tuple(arg_name, arg_value, prim_name, allow_five=False, ret
if third_one:
if ret_value[-3] != 1:
_raise_message(third_one)
_raise_message(third_one_flag=third_one)
return tuple(ret_value)

@ -441,7 +441,7 @@ class Conv3DBackpropFilter(PrimitiveWithInfer):
out = {
'value': None,
'shape': w_size_v,
'dtype': x['dtype'],
'dtype': mstype.float32,
}
return out

@ -7425,20 +7425,20 @@ class Conv3D(PrimitiveWithInfer):
the depth, height and width of movement are both strides, or a tuple of three int numbers that
represent depth, height and width of movement respectively. Default: 1.
pad_mode (str): Specifies padding mode. The optional values are
"same", "valid", "pad". Default: "same".
"same", "valid", "pad". Default: "valid".
- same: Adopts the way of completion. The depth, height and width of the output will be the same as
the input. The total number of padding will be calculated in depth, horizontal and vertical
directions and evenly distributed to head and tail, top and bottom, left and right if possible.
Otherwise, the last extra padding will be done from the tail, bottom and the right side.
If this mode is set, `padding` must be 0.
If this mode is set, `pad` must be 0.
- valid: Adopts the way of discarding. The possible largest depth, height and width of output
will be returned without padding. Extra pixels will be discarded. If this mode is set, `padding`
will be returned without padding. Extra pixels will be discarded. If this mode is set, `pad`
must be 0.
- pad: Implicit paddings on both sides of the input. The number of `padding` will be padded to the input
Tensor borders. `padding` must be greater than or equal to 0.
- pad: Implicit paddings on both sides of the input in depth, height, width. The number of `pad` will
be padded to the input Tensor borders. `pad` must be greater than or equal to 0.
pad (Union(int, tuple[int])): The pad value to be filled. Default: 0. If `pad` is an integer, the paddings of
head, tail, top, bottom, left and right are the same, equal to pad. If `pad` is a tuple of six
@ -7502,16 +7502,17 @@ class Conv3D(PrimitiveWithInfer):
self.init_prim_io_names(inputs=['x', 'w'], outputs=['output'])
self.kernel_size = _check_3d_int_or_tuple('kernel_size', kernel_size, self.name)
self.stride = _check_3d_int_or_tuple('stride', stride, self.name, allow_five=True,
ret_five=True, three_input=True)
ret_five=True)
self.add_prim_attr('strides', self.stride)
self.dilation = _check_3d_int_or_tuple('dilation', dilation, self.name, allow_five=True,
ret_five=True, third_one=True, three_input=True)
ret_five=True, third_one=True)
self.add_prim_attr('dilations', self.dilation)
validator.check_value_type('pad', pad, (int, tuple), self.name)
if isinstance(pad, int):
pad = (pad,) * 6
if len(pad) != 6:
raise ValueError(f'the size of pad in `conv3d` must be 6, but got `{len(pad)}`.')
raise ValueError(f"For `conv3d` attr 'pad' should be an positive int number or a tuple of "
f"six positive int numbers, but got `{len(pad)}`.")
self.add_prim_attr("pad", pad)
self.padding = pad
validator.check_int_range(self.padding[0], 0, self.kernel_size[0], Rel.INC_LEFT,
@ -7853,16 +7854,17 @@ class Conv3DTranspose(PrimitiveWithInfer):
self.add_prim_attr('out_channel', self.out_channel)
self.kernel_size = _check_3d_int_or_tuple('kernel_size', kernel_size, self.name)
self.stride = _check_3d_int_or_tuple('stride', stride, self.name, allow_five=True,
ret_five=True, three_input=True)
ret_five=True)
self.add_prim_attr('strides', self.stride)
self.dilation = _check_3d_int_or_tuple('dilation', dilation, self.name, allow_five=True,
ret_five=True, third_one=True, three_input=True)
ret_five=True, third_one=True)
self.add_prim_attr('dilations', self.dilation)
validator.check_value_type('pad', pad, (int, tuple), self.name)
if isinstance(pad, int):
pad = (pad,) * 6
if len(pad) != 6:
raise ValueError(f'the size of pad in `conv3d` must be 6, but got `{len(pad)}`.')
raise ValueError(f"For `conv3d` attr 'pad' should be an positive int number or a tuple of "
f"six positive int numbers, but got `{len(pad)}`.")
self.pad_list = pad
for item in self.pad_list:
validator.check_non_negative_int(item, 'pad item', self.name)

@ -344,7 +344,7 @@ class ResNet(nn.Cell):
num_classes (int): The number of classes that the training images are belonging to.
use_se (bool): Enable SE-ResNet50 net. Default: False.
se_block(bool): Use se block in SE-ResNet50 net in layer 3 and layer 4. Default: False.
res_base (bool): Enable parameter setting of resnet18. Default: True.
res_base (bool): Enable parameter setting of resnet18. Default: False.
Returns:
Tensor, output tensor.

Loading…
Cancel
Save