|
|
|
@ -1904,7 +1904,7 @@ class MaxPoolWithArgmax(_Pool):
|
|
|
|
|
|
|
|
|
|
class MaxPool3D(PrimitiveWithInfer):
|
|
|
|
|
r"""
|
|
|
|
|
Max pooling operation.
|
|
|
|
|
3D max pooling operation.
|
|
|
|
|
|
|
|
|
|
Applies a 3D max pooling over an input Tensor which can be regarded as a composition of 3D planes.
|
|
|
|
|
|
|
|
|
@ -1947,7 +1947,7 @@ class MaxPool3D(PrimitiveWithInfer):
|
|
|
|
|
TypeError: If `pad_mode` or `data_format` is not a string.
|
|
|
|
|
ValueError: If numbers in `kernel_size` or `strides` are not positive.
|
|
|
|
|
ValueError: If `pad_mode` is not one of 'same', 'valid'.
|
|
|
|
|
ValueError: If `kernel_size` or `strides` is a tuple whose length is not equal to 3 or 5.
|
|
|
|
|
ValueError: If `kernel_size` or `strides` is a tuple whose length is not equal to 3.
|
|
|
|
|
ValueError: If `data_format` is not 'NCDHW'.
|
|
|
|
|
|
|
|
|
|
Supported Platforms:
|
|
|
|
@ -1971,9 +1971,10 @@ class MaxPool3D(PrimitiveWithInfer):
|
|
|
|
|
self.pad_mode = validator.check_string(pad_mode.upper(), ['VALID', 'SAME'], 'pad_mode', self.name)
|
|
|
|
|
self.add_prim_attr("pad_mode", self.pad_mode)
|
|
|
|
|
self.data_format = validator.check_string(data_format, ['NCDHW'], 'data_format', self.name)
|
|
|
|
|
self.kernel_size = _check_3d_int_or_tuple("kernel_size", kernel_size, self.name, allow_five=True, ret_five=True)
|
|
|
|
|
self.kernel_size = _check_3d_int_or_tuple("kernel_size", kernel_size, self.name,
|
|
|
|
|
allow_five=False, ret_five=True)
|
|
|
|
|
self.add_prim_attr("kernel_size", self.kernel_size)
|
|
|
|
|
self.strides = _check_3d_int_or_tuple("strides", strides, self.name, allow_five=True, ret_five=True)
|
|
|
|
|
self.strides = _check_3d_int_or_tuple("strides", strides, self.name, allow_five=False, ret_five=True)
|
|
|
|
|
self.add_prim_attr("strides", self.strides)
|
|
|
|
|
|
|
|
|
|
def infer_shape(self, x_shape):
|
|
|
|
@ -2274,7 +2275,7 @@ class BiasAdd(PrimitiveWithCheck):
|
|
|
|
|
self.add_prim_attr('data_format', self.format)
|
|
|
|
|
|
|
|
|
|
def check_shape(self, x_shape, b_shape):
|
|
|
|
|
validator.check_int(len(x_shape), 2, Rel.GE, "x rank", self.name)
|
|
|
|
|
validator.check_int_range(len(x_shape), 2, 5, Rel.INC_BOTH, "x rank", self.name)
|
|
|
|
|
if self.format == "NCDHW" and (len(x_shape) != 5 or context.get_context("device_target") != "Ascend"):
|
|
|
|
|
raise ValueError("NCDHW format only support 5-dims input in Ascend target.")
|
|
|
|
|
validator.check_equal_int(len(b_shape), 1, "bias rank", self.name)
|
|
|
|
|