fixed ScalarCast

pull/1652/head
jiangjinsheng 5 years ago
parent cbdc59e861
commit 0ac47f2f71

@ -15,6 +15,9 @@
"""inner_ops"""
import numbers
from ..._checkparam import Validator as validator
from ..._checkparam import Rel
from ...common.dtype import tensor, dtype_to_pytype
from ..primitive import prim_attr_register, PrimitiveWithInfer
@ -40,8 +43,10 @@ class ScalarCast(PrimitiveWithInfer):
pass
def __infer__(self, x, t):
validator.check_integer('x shape', len(x['shape']), 0, Rel.EQ, self.name)
value, to = x['value'], t['value']
if value is not None:
validator.check_value_type("value", value, [numbers.Number, bool], self.name)
if isinstance(to, type(tensor)):
to = to.element_type()
np_type = dtype_to_pytype(to)

Loading…
Cancel
Save