|
|
|
@ -66,6 +66,8 @@ class _BatchNorm(Cell):
|
|
|
|
|
self.format = validator.check_string(data_format, ['NCDHW'], 'format', self.cls_name)
|
|
|
|
|
else:
|
|
|
|
|
self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.cls_name)
|
|
|
|
|
if context.get_context("device_target") != "Ascend" and self.format == "NCDHW":
|
|
|
|
|
raise ValueError("NCDHW format only support in Ascend target.")
|
|
|
|
|
if context.get_context("device_target") != "GPU" and self.format == "NHWC":
|
|
|
|
|
raise ValueError("NHWC format only support in GPU target.")
|
|
|
|
|
self.use_batch_statistics = use_batch_statistics
|
|
|
|
|