diff --git a/mindspore/ops/operations/control_ops.py b/mindspore/ops/operations/control_ops.py index 736855a046..f413fac0a9 100644 --- a/mindspore/ops/operations/control_ops.py +++ b/mindspore/ops/operations/control_ops.py @@ -69,6 +69,8 @@ class ControlDepend(Primitive): @prim_attr_register def __init__(self, depend_mode=0): """init""" + validator.check_int_range( + "depend_mode", depend_mode, 0, 1, Rel.INC_BOTH, self.name) def __call__(self, src, dst): return src @@ -128,8 +130,10 @@ class GeSwitch(PrimitiveWithInfer): return (data, data) def infer_dtype(self, data_type, pred_type): - validator.check_subclass("data", data_type, (mstype.tensor,) + mstype.number_type, self.name) - validator.check_tensor_type_same({"pred": pred_type}, [mstype.bool_], self.name) + validator.check_subclass( + "data", data_type, (mstype.tensor,) + mstype.number_type, self.name) + validator.check_tensor_type_same( + {"pred": pred_type}, [mstype.bool_], self.name) return (data_type, data_type) @@ -161,5 +165,6 @@ class Merge(PrimitiveWithInfer): for i, item in enumerate(inputs): args['inputs[%d]' % i] = item - validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name) + validator.check_tensor_type_same( + args, (mstype.bool_,) + mstype.number_type, self.name) return (inputs[0], mstype.int32) diff --git a/tests/ut/python/ops/test_control_ops.py b/tests/ut/python/ops/test_control_ops.py index 8d81811d9b..e77f6822b0 100644 --- a/tests/ut/python/ops/test_control_ops.py +++ b/tests/ut/python/ops/test_control_ops.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================ """ test control ops """ +import pytest import numpy as np import mindspore as ms @@ -434,3 +435,11 @@ def test_index_to_switch_layer(): C.grad_by_list(net, ParameterTuple(net.trainable_params()))(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32))) C.grad_all(net)(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32))) + +def test_control_depend_check(): + with pytest.raises(TypeError) as e: + depend = P.ControlDepend(0.0) + with pytest.raises(ValueError) as e: + depend = P.ControlDepend(2) + with pytest.raises(TypeError) as e: + depend = P.ControlDepend((2,)) \ No newline at end of file