|
|
|
@ -124,16 +124,9 @@ class MaxPool2d(_PoolNd):
|
|
|
|
|
strides=self.stride,
|
|
|
|
|
padding=self.pad_mode,
|
|
|
|
|
data_format=self.format)
|
|
|
|
|
self.max_pool_with_arg_max = P.MaxPoolWithArgmax(ksize=self.kernel_size,
|
|
|
|
|
strides=self.stride,
|
|
|
|
|
padding=self.pad_mode)
|
|
|
|
|
self.is_tbe = context.get_context("device_target") == "Ascend"
|
|
|
|
|
|
|
|
|
|
def construct(self, x):
|
|
|
|
|
if self.is_tbe and self.training:
|
|
|
|
|
out = self.max_pool_with_arg_max(x)[0]
|
|
|
|
|
else:
|
|
|
|
|
out = self.max_pool(x)
|
|
|
|
|
out = self.max_pool(x)
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -198,22 +191,15 @@ class MaxPool1d(_PoolNd):
|
|
|
|
|
self.max_pool = P.MaxPool(ksize=self.kernel_size,
|
|
|
|
|
strides=self.stride,
|
|
|
|
|
padding=self.pad_mode)
|
|
|
|
|
self.max_pool_with_arg_max = P.MaxPoolWithArgmax(ksize=self.kernel_size,
|
|
|
|
|
strides=self.stride,
|
|
|
|
|
padding=self.pad_mode)
|
|
|
|
|
self.shape = F.shape
|
|
|
|
|
self.reduce_mean = P.ReduceMean(keep_dims=True)
|
|
|
|
|
self.expand = P.ExpandDims()
|
|
|
|
|
self.squeeze = P.Squeeze(2)
|
|
|
|
|
self.is_tbe = context.get_context("device_target") == "Ascend"
|
|
|
|
|
|
|
|
|
|
def construct(self, x):
|
|
|
|
|
_shape_check(self.shape(x))
|
|
|
|
|
x = self.expand(x, 2)
|
|
|
|
|
if self.is_tbe and self.training:
|
|
|
|
|
output = self.max_pool_with_arg_max(x)[0]
|
|
|
|
|
else:
|
|
|
|
|
output = self.max_pool(x)
|
|
|
|
|
output = self.max_pool(x)
|
|
|
|
|
output = self.squeeze(output)
|
|
|
|
|
return output
|
|
|
|
|
|
|
|
|
|