diff --git a/mindspore/_extends/graph_kernel/model/op_infer.py b/mindspore/_extends/graph_kernel/model/op_infer.py index 8ae890364e..6738b29924 100644 --- a/mindspore/_extends/graph_kernel/model/op_infer.py +++ b/mindspore/_extends/graph_kernel/model/op_infer.py @@ -273,3 +273,14 @@ class Greater(_CompareOp): class GreaterEqual(_CompareOp): pass + + +class Select(_Elemwise): + def _check_type(self): + if self.inputs[0].dtype != "bool": + raise GKException("Select's input[0] should be a bool condition but got {}".format(self.inputs[0].dtype)) + if self.inputs[1].dtype != self.inputs[2].dtype: + raise GKException("Select's input mismatch ({} vs {})".format(self.inputs[1].dtype, self.inputs[2].dtype)) + + def _infer_type(self): + return self.inputs[1].dtype