|
|
|
@ -23,7 +23,7 @@ from .. import signature as sig
|
|
|
|
|
from ..._checkparam import Validator as validator
|
|
|
|
|
from ..._checkparam import Rel
|
|
|
|
|
from ...common import dtype as mstype
|
|
|
|
|
from ...common.tensor import Tensor
|
|
|
|
|
from ...common.tensor import Tensor, MetaTensor
|
|
|
|
|
from .._utils import get_broadcast_shape
|
|
|
|
|
from ..primitive import PrimitiveWithInfer, PrimitiveWithCheck, prim_attr_register, _run_op
|
|
|
|
|
|
|
|
|
@ -2324,9 +2324,13 @@ class Equal(_LogicBinaryOp):
|
|
|
|
|
return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, mstype.number_type + (mstype.bool_,), self.name)
|
|
|
|
|
|
|
|
|
|
def infer_value(self, x, y):
|
|
|
|
|
if x is not None and y is not None:
|
|
|
|
|
return Tensor(x.asnumpy() == y.asnumpy())
|
|
|
|
|
return None
|
|
|
|
|
if x is None or y is None:
|
|
|
|
|
return None
|
|
|
|
|
if isinstance(x, MetaTensor):
|
|
|
|
|
x = x.to_tensor()
|
|
|
|
|
if isinstance(y, MetaTensor):
|
|
|
|
|
y = y.to_tensor()
|
|
|
|
|
return Tensor(x.asnumpy() == y.asnumpy())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ApproximateEqual(_LogicBinaryOp):
|
|
|
|
|