|
|
|
@ -15,6 +15,9 @@
|
|
|
|
|
|
|
|
|
|
"""inner_ops"""
|
|
|
|
|
|
|
|
|
|
import numbers
|
|
|
|
|
from ..._checkparam import Validator as validator
|
|
|
|
|
from ..._checkparam import Rel
|
|
|
|
|
from ...common.dtype import tensor, dtype_to_pytype
|
|
|
|
|
from ..primitive import prim_attr_register, PrimitiveWithInfer
|
|
|
|
|
|
|
|
|
@ -40,8 +43,10 @@ class ScalarCast(PrimitiveWithInfer):
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
def __infer__(self, x, t):
|
|
|
|
|
validator.check_integer('x shape', len(x['shape']), 0, Rel.EQ, self.name)
|
|
|
|
|
value, to = x['value'], t['value']
|
|
|
|
|
if value is not None:
|
|
|
|
|
validator.check_value_type("value", value, [numbers.Number, bool], self.name)
|
|
|
|
|
if isinstance(to, type(tensor)):
|
|
|
|
|
to = to.element_type()
|
|
|
|
|
np_type = dtype_to_pytype(to)
|
|
|
|
|