|
|
|
@ -186,11 +186,13 @@ class Cast(PrimitiveWithInfer):
|
|
|
|
|
self.init_prim_io_names(inputs=['x', 'dst_type'], outputs=['output'])
|
|
|
|
|
|
|
|
|
|
def check_elim(self, x, dtype):
|
|
|
|
|
if isinstance(x, Tensor):
|
|
|
|
|
if x.dtype == dtype:
|
|
|
|
|
if isinstance(x, (Tensor, numbers.Number)):
|
|
|
|
|
if isinstance(x, Tensor) and x.dtype == dtype:
|
|
|
|
|
return (True, x)
|
|
|
|
|
if isinstance(x, numbers.Number):
|
|
|
|
|
return (True, Tensor(x, dtype=dtype))
|
|
|
|
|
return (False, None)
|
|
|
|
|
raise ValueError("Expecting (Tensor, dtype), got : {}".format(inputs))
|
|
|
|
|
raise ValueError(f"Expecting (Tensor, dtype), got : ({x}, {dtype})")
|
|
|
|
|
|
|
|
|
|
def __infer__(self, x, t):
|
|
|
|
|
src_type = x['dtype']
|
|
|
|
|