|
|
@ -128,6 +128,7 @@ class GeSwitch(PrimitiveWithInfer):
|
|
|
|
return (data, data)
|
|
|
|
return (data, data)
|
|
|
|
|
|
|
|
|
|
|
|
def infer_dtype(self, data_type, pred_type):
|
|
|
|
def infer_dtype(self, data_type, pred_type):
|
|
|
|
|
|
|
|
validator.check_subclass("data", data_type, (mstype.tensor,) + mstype.number_type, self.name)
|
|
|
|
validator.check_tensor_type_same({"pred": pred_type}, [mstype.bool_], self.name)
|
|
|
|
validator.check_tensor_type_same({"pred": pred_type}, [mstype.bool_], self.name)
|
|
|
|
return (data_type, data_type)
|
|
|
|
return (data_type, data_type)
|
|
|
|
|
|
|
|
|
|
|
@ -153,7 +154,20 @@ class Merge(PrimitiveWithInfer):
|
|
|
|
raise NotImplementedError
|
|
|
|
raise NotImplementedError
|
|
|
|
|
|
|
|
|
|
|
|
def infer_shape(self, inputs):
|
|
|
|
def infer_shape(self, inputs):
|
|
|
|
|
|
|
|
validator.check_integer('inputs len', len(inputs), 0, Rel.GT, self.name)
|
|
|
|
|
|
|
|
input_0 = inputs[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for i in range(1, len(inputs)):
|
|
|
|
|
|
|
|
if inputs[i] != input_0:
|
|
|
|
|
|
|
|
raise ValueError(f"For \'{self.name}\', the shape of {i}th input should be same as "
|
|
|
|
|
|
|
|
f"first input {input_0}, but got {inputs[i]}.")
|
|
|
|
|
|
|
|
|
|
|
|
return (inputs[0], [1])
|
|
|
|
return (inputs[0], [1])
|
|
|
|
|
|
|
|
|
|
|
|
def infer_dtype(self, inputs):
|
|
|
|
def infer_dtype(self, inputs):
|
|
|
|
|
|
|
|
args = {}
|
|
|
|
|
|
|
|
for i, item in enumerate(inputs):
|
|
|
|
|
|
|
|
args['inputs[%d]' % i] = item
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
validator.check_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
|
|
|
|
return (inputs[0], mstype.int32)
|
|
|
|
return (inputs[0], mstype.int32)
|
|
|
|