|
|
|
@ -84,7 +84,7 @@ class GeSwitch(PrimitiveWithInfer):
|
|
|
|
|
the true branch will be activated, or vise verse.
|
|
|
|
|
|
|
|
|
|
Inputs:
|
|
|
|
|
- **data** (Tensor) - The data to be used for switch control.
|
|
|
|
|
- **data** (Union[Tensor, Number]) - The data to be used for switch control.
|
|
|
|
|
- **pred** (Tensor) - It should be a scalar whose type is bool and shape is `()`, It is used as condition for
|
|
|
|
|
switch control.
|
|
|
|
|
Outputs:
|
|
|
|
@ -144,7 +144,7 @@ class Merge(PrimitiveWithInfer):
|
|
|
|
|
One and only one of the inputs should be selected as the output
|
|
|
|
|
|
|
|
|
|
Inputs:
|
|
|
|
|
- **inputs** (Tuple) - The data to be merged.
|
|
|
|
|
- **inputs** (Tuple) - The data to be merged. All tuple elements should have same data type.
|
|
|
|
|
|
|
|
|
|
Outputs:
|
|
|
|
|
tuple. Output is tuple(`data`, `output_index`). The `data` has the same shape of `inputs` element.
|
|
|
|
@ -171,6 +171,5 @@ class Merge(PrimitiveWithInfer):
|
|
|
|
|
for i, item in enumerate(inputs):
|
|
|
|
|
args['inputs[%d]' % i] = item
|
|
|
|
|
|
|
|
|
|
validator.check_tensor_type_same(
|
|
|
|
|
args, (mstype.bool_,) + mstype.number_type, self.name)
|
|
|
|
|
validator.check_scalar_or_tensor_type_same(args, (mstype.bool_,) + mstype.number_type, self.name)
|
|
|
|
|
return (inputs[0], mstype.int32)
|
|
|
|
|