|
|
|
@ -69,6 +69,8 @@ class ControlDepend(Primitive):
|
|
|
|
|
@prim_attr_register
|
|
|
|
|
def __init__(self, depend_mode=0):
|
|
|
|
|
"""init"""
|
|
|
|
|
validator.check_int_range(
|
|
|
|
|
"depend_mode", depend_mode, 0, 1, Rel.INC_BOTH, self.name)
|
|
|
|
|
|
|
|
|
|
def __call__(self, src, dst):
|
|
|
|
|
return src
|
|
|
|
@ -128,8 +130,10 @@ class GeSwitch(PrimitiveWithInfer):
|
|
|
|
|
return (data, data)
|
|
|
|
|
|
|
|
|
|
def infer_dtype(self, data_type, pred_type):
|
|
|
|
|
validator.check_subclass("data", data_type, (mstype.tensor,) + mstype.number_type, self.name)
|
|
|
|
|
validator.check_tensor_type_same({"pred": pred_type}, [mstype.bool_], self.name)
|
|
|
|
|
validator.check_subclass(
|
|
|
|
|
"data", data_type, (mstype.tensor,) + mstype.number_type, self.name)
|
|
|
|
|
validator.check_tensor_type_same(
|
|
|
|
|
{"pred": pred_type}, [mstype.bool_], self.name)
|
|
|
|
|
return (data_type, data_type)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -161,5 +165,6 @@ class Merge(PrimitiveWithInfer):
|
|
|
|
|
for i, item in enumerate(inputs):
|
|
|
|
|
args['inputs[%d]' % i] = item
|
|
|
|
|
|
|
|
|
|
validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
|
|
|
|
|
validator.check_tensor_type_same(
|
|
|
|
|
args, (mstype.bool_,) + mstype.number_type, self.name)
|
|
|
|
|
return (inputs[0], mstype.int32)
|
|
|
|
|