From a6959c2a135587f9280e3cd1ee5c44d2b8ca7f90 Mon Sep 17 00:00:00 2001 From: caifubi Date: Tue, 2 Mar 2021 14:16:45 +0800 Subject: [PATCH] fix bn cast bug --- mindspore/ccsrc/pybind_api/ir/primitive_py.cc | 2 +- mindspore/ops/operations/nn_ops.py | 10 ++++++++++ tests/st/mix_precision/test_mix_precision.py | 2 +- tests/ut/python/parallel/test_bn_prelu_cell.py | 4 ++-- 4 files changed, 14 insertions(+), 4 deletions(-) diff --git a/mindspore/ccsrc/pybind_api/ir/primitive_py.cc b/mindspore/ccsrc/pybind_api/ir/primitive_py.cc index 23cfee8d67..36162e65e7 100644 --- a/mindspore/ccsrc/pybind_api/ir/primitive_py.cc +++ b/mindspore/ccsrc/pybind_api/ir/primitive_py.cc @@ -75,7 +75,7 @@ PrimitivePy::~PrimitivePy() { void PrimitivePy::SetPyObj(const py::object &obj) { python_obj_ = obj; } void PrimitivePy::set_signatures(const std::vector &signatures) { signatures_ = signatures; - set_has_signature(true); + set_has_signature(!signatures.empty()); } py::function PrimitivePy::GetBpropFunction() { diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index d8b1a1c7b3..0835f6bf07 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -1303,8 +1303,18 @@ class BatchNorm(PrimitiveWithInfer): [ 1.00000000e+00, 1.00000000e+00])) """ + __mindspore_signature__ = ( + sig.make_sig('input_x', dtype=sig.sig_dtype.T1), + sig.make_sig('scale', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T2), + sig.make_sig('bias', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T2), + sig.make_sig('mean', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T3), + sig.make_sig('variance', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T3) + ) + @prim_attr_register def __init__(self, is_training=False, epsilon=1e-5, momentum=0.1, data_format="NCHW"): + if is_training is False: + self.set_signatures(tuple()) validator.check_value_type('is_training', is_training, (bool,), self.name) validator.check_float_range(epsilon, 0, 1, Rel.INC_RIGHT, 'epsilon', self.name) validator.check_float_range(momentum, 0, 1, Rel.INC_BOTH, 'momentum', self.name) diff --git a/tests/st/mix_precision/test_mix_precision.py b/tests/st/mix_precision/test_mix_precision.py index 01aaee19e5..5ba8f83725 100644 --- a/tests/st/mix_precision/test_mix_precision.py +++ b/tests/st/mix_precision/test_mix_precision.py @@ -129,7 +129,7 @@ def test_sit_auto_mix_precision_model_o0(): model.train(1, dataset1, dataset_sink_mode=False) contend = read_validateir_file('./test_amp_o0') castnum = re.findall("Cast", contend) - assert len(castnum) == 17 + assert len(castnum) == 5 model.predict(Tensor(input_data)) contend = read_validateir_file('./test_amp_o0') castnum = re.findall("Cast", contend) diff --git a/tests/ut/python/parallel/test_bn_prelu_cell.py b/tests/ut/python/parallel/test_bn_prelu_cell.py index d94a79deb6..64fcd4538e 100644 --- a/tests/ut/python/parallel/test_bn_prelu_cell.py +++ b/tests/ut/python/parallel/test_bn_prelu_cell.py @@ -109,8 +109,8 @@ class FusedBatchNorm(nn.Cell): self.bn_train(x, self.gamma, self.beta, - None, - None) + self.moving_mean, + self.moving_variance) mean_sub = self.sub_mean(self.moving_mean, batch_mean) temp_mean = self.mul_mean(mean_sub, self.momentum)