|
|
|
@ -184,6 +184,13 @@ class Cast(PrimitiveWithInfer):
|
|
|
|
|
"""init Cast"""
|
|
|
|
|
self.init_prim_io_names(inputs=['x', 'dst_type'], outputs=['output'])
|
|
|
|
|
|
|
|
|
|
def check_elim(self, x, dtype):
|
|
|
|
|
if isinstance(x, Tensor):
|
|
|
|
|
if x.dtype() == dtype:
|
|
|
|
|
return (True, x)
|
|
|
|
|
return (False, None)
|
|
|
|
|
raise ValueError("Expecting (Tensor, dtype), got : {}".format(inputs))
|
|
|
|
|
|
|
|
|
|
def __infer__(self, x, t):
|
|
|
|
|
src_type = x['dtype']
|
|
|
|
|
dst_type = t['value']
|
|
|
|
@ -1310,6 +1317,15 @@ class Tile(PrimitiveWithInfer):
|
|
|
|
|
"""init Tile"""
|
|
|
|
|
self.init_prim_io_names(inputs=['x', 'multiples'], outputs=['output'])
|
|
|
|
|
|
|
|
|
|
def check_elim(self, base_tensor, multiplier):
|
|
|
|
|
if (not isinstance(base_tensor, Tensor)) or (not isinstance(multiplier, tuple)):
|
|
|
|
|
raise ValueError("Expecting (Tensor, tuple), got: ({}, {})".format(base_tensor, multiplier))
|
|
|
|
|
def is_all_zeros(v_tuple):
|
|
|
|
|
return all(v == 1 for v in v_tuple)
|
|
|
|
|
if is_all_zeros(multiplier):
|
|
|
|
|
return (True, base_tensor)
|
|
|
|
|
return (False, None)
|
|
|
|
|
|
|
|
|
|
def __infer__(self, x, multiples):
|
|
|
|
|
multiples_v = multiples['value']
|
|
|
|
|
x_shp = x['shape']
|
|
|
|
|