Support MetaTensor in Equal's infer_value

pull/8690/head
huanghui 4 years ago
parent 17acf2bcaa
commit a423320d58

@ -23,7 +23,7 @@ from .. import signature as sig
from ..._checkparam import Validator as validator
from ..._checkparam import Rel
from ...common import dtype as mstype
from ...common.tensor import Tensor
from ...common.tensor import Tensor, MetaTensor
from .._utils import get_broadcast_shape
from ..primitive import PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op
@ -2324,9 +2324,13 @@ class Equal(_LogicBinaryOp):
return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type + (mstype.bool_,), self.name)
def infer_value(self, x, y):
if x is not None and y is not None:
return Tensor(x.asnumpy() == y.asnumpy())
return None
if x is None or y is None:
return None
if isinstance(x, MetaTensor):
x = x.to_tensor()
if isinstance(y, MetaTensor):
y = y.to_tensor()
return Tensor(x.asnumpy() == y.asnumpy())
class ApproximateEqual(_LogicBinaryOp):

Loading…
Cancel
Save