add OpInfer for op Select

pull/12902/head
dayschan 4 years ago
parent 5135c214b7
commit 454500309c

@ -273,3 +273,14 @@ class Greater(_CompareOp):
class GreaterEqual(_CompareOp):
pass
class Select(_Elemwise):
def _check_type(self):
if self.inputs[0].dtype != "bool":
raise GKException("Select's input[0] should be a bool condition but got {}".format(self.inputs[0].dtype))
if self.inputs[1].dtype != self.inputs[2].dtype:
raise GKException("Select's input mismatch ({} vs {})".format(self.inputs[1].dtype, self.inputs[2].dtype))
def _infer_type(self):
return self.inputs[1].dtype

Loading…
Cancel
Save