!3225 Fix input verification for input of GlobalbatchNorm.

Merge pull request !3225 from liuxiao93/fix-check-input-for-GlobalBatchNorm
pull/3225/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 5c81f9caa0

@ -45,7 +45,7 @@ class _BatchNorm(Cell):
moving_var_init='ones', moving_var_init='ones',
use_batch_statistics=None, use_batch_statistics=None,
device_num_each_group=1, device_num_each_group=1,
input_dims='1d'): input_dims='2d'):
super(_BatchNorm, self).__init__() super(_BatchNorm, self).__init__()
if num_features < 1: if num_features < 1:
raise ValueError("num_features must be at least 1") raise ValueError("num_features must be at least 1")
@ -151,6 +151,8 @@ class _BatchNorm(Cell):
_shape_check(self.shape(x)) _shape_check(self.shape(x))
if self.input_dims == '1d': if self.input_dims == '1d':
_shape_check_2d(self.shape(x)) _shape_check_2d(self.shape(x))
if self.input_dims == 'both':
_shape_check_2d_or_4d(self.shape(x))
if self.use_batch_statistics is None: if self.use_batch_statistics is None:
flag = self.training flag = self.training
else: else:
@ -211,7 +213,13 @@ def _shape_check_2d(input_shape):
@constexpr @constexpr
def _shape_check(in_shape): def _shape_check(in_shape):
if len(in_shape) != 4: if len(in_shape) != 4:
raise ValueError("The input must has 4 dims") raise ValueError("The input must has 4 dims.")
@constexpr
def _shape_check_2d_or_4d(in_shape):
if len(in_shape) != 2 and len(in_shape) != 4:
raise ValueError("The input must has 2 dims or 4 dims.")
@constexpr @constexpr
@ -449,7 +457,8 @@ class GlobalBatchNorm(_BatchNorm):
moving_mean_init, moving_mean_init,
moving_var_init, moving_var_init,
use_batch_statistics, use_batch_statistics,
device_num_each_group) device_num_each_group,
input_dims='both')
self.group = check_int_positive(device_num_each_group) self.group = check_int_positive(device_num_each_group)
if self.group <= 1: if self.group <= 1:
raise ValueError("the number of group must be greater than 1.") raise ValueError("the number of group must be greater than 1.")

Loading…
Cancel
Save