|
|
|
@ -18,6 +18,7 @@ from mindspore.ops import functional as F
|
|
|
|
|
from mindspore._checkparam import Validator as validator
|
|
|
|
|
from ... import context
|
|
|
|
|
from ..cell import Cell
|
|
|
|
|
from ..._checkparam import Rel
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _PoolNd(Cell):
|
|
|
|
@ -263,10 +264,15 @@ class AvgPool1d(_PoolNd):
|
|
|
|
|
stride=1,
|
|
|
|
|
pad_mode="valid"):
|
|
|
|
|
super(AvgPool1d, self).__init__(kernel_size, stride, pad_mode)
|
|
|
|
|
validator.check_type('kernel_size', kernel_size, [int,])
|
|
|
|
|
validator.check_type('stride', stride, [int,])
|
|
|
|
|
self.padding = validator.check_string('pad_mode', pad_mode.upper(), ['VALID', 'SAME'])
|
|
|
|
|
if not isinstance(kernel_size, int):
|
|
|
|
|
validator.check_integer("kernel_size", kernel_size, 1, Rel.GE)
|
|
|
|
|
raise ValueError("kernel_size should be 1 int number but got {}".
|
|
|
|
|
format(kernel_size))
|
|
|
|
|
if not isinstance(stride, int):
|
|
|
|
|
validator.check_integer("stride", stride, 1, Rel.GE)
|
|
|
|
|
raise ValueError("stride should be 1 int number but got {}".format(stride))
|
|
|
|
|
self.kernel_size = (1, kernel_size)
|
|
|
|
|
self.stride = (1, stride)
|
|
|
|
|