@@ -1303,8 +1303,18 @@ class BatchNorm(PrimitiveWithInfer):
        [ 1.00000000e+00, 1.00000000e+00]))
    """

    __mindspore_signature__ = (
        sig.make_sig('input_x', dtype=sig.sig_dtype.T1),
        sig.make_sig('scale', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T2),
        sig.make_sig('bias', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T2),
        sig.make_sig('mean', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T3),
        sig.make_sig('variance', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T3)
    )

    @prim_attr_register
    def __init__(self, is_training=False, epsilon=1e-5, momentum=0.1, data_format="NCHW"):
        if is_training is False:
            self.set_signatures(tuple())
        validator.check_value_type('is_training', is_training, (bool,), self.name)
        validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name)
        validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name)
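
For context, a minimal usage sketch (not part of the diff) of the primitive this hunk touches, assuming the public mindspore.ops.BatchNorm export and the docstring example above; shapes and values are illustrative only. With is_training=False the new branch calls set_signatures(tuple()), so scale, bias, mean and variance lose their RW_WRITE signatures and are consumed as plain read-only inputs.

import numpy as np
import mindspore
from mindspore import Tensor, ops

# Inference-mode BatchNorm: is_training=False exercises the branch added in
# this hunk (set_signatures(tuple())), so mean/variance are not updated.
batch_norm = ops.BatchNorm(is_training=False, epsilon=1e-5, momentum=0.1, data_format="NCHW")

input_x = Tensor(np.ones([2, 2]), mindspore.float32)
scale = Tensor(np.ones([2]), mindspore.float32)
bias = Tensor(np.ones([2]), mindspore.float32)
mean = Tensor(np.ones([2]), mindspore.float32)
variance = Tensor(np.ones([2]), mindspore.float32)

output = batch_norm(input_x, scale, bias, mean, variance)
# With unit scale/bias/mean/variance, (x - mean) / sqrt(var + eps) * scale + bias == 1,
# matching the all-ones value shown in the docstring example.
print(output[0])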