fix cast check elim

pull/2067/head
BowenK 5 years ago
parent b096383386
commit 35a57e076d

@ -186,11 +186,13 @@ class Cast(PrimitiveWithInfer):
self.init_prim_io_names(inputs=['x', 'dst_type'], outputs=['output'])
def check_elim(self, x, dtype):
if isinstance(x, Tensor):
if x.dtype == dtype:
if isinstance(x, (Tensor, numbers.Number)):
if isinstance(x, Tensor) and x.dtype == dtype:
return (True, x)
if isinstance(x, numbers.Number):
return (True, Tensor(x, dtype=dtype))
return (False, None)
raise ValueError("Expecting (Tensor, dtype), got : {}".format(inputs))
raise ValueError(f"Expecting (Tensor, dtype), got : ({x}, {dtype})")
def __infer__(self, x, t):
src_type = x['dtype']

Loading…
Cancel
Save