|
|
|
@ -13,13 +13,13 @@
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
|
|
"""operator dsl function: equal"""
|
|
|
|
|
import akg.tvm
|
|
|
|
|
import akg.topi
|
|
|
|
|
from akg.utils.dsl_create import produce_shapes
|
|
|
|
|
from akg.utils import validation_check as vc_util
|
|
|
|
|
import _akg.tvm
|
|
|
|
|
import _akg.topi
|
|
|
|
|
from _akg.utils.dsl_create import produce_shapes
|
|
|
|
|
from _akg.utils import validation_check as vc_util
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@vc_util.check_input_type(akg.tvm.tensor.Tensor, akg.tvm.tensor.Tensor)
|
|
|
|
|
@vc_util.check_input_type(_akg.tvm.tensor.Tensor, _akg.tvm.tensor.Tensor)
|
|
|
|
|
def equal(input1, input2):
|
|
|
|
|
"""
|
|
|
|
|
check whether input1 equals to input2.
|
|
|
|
@ -42,13 +42,13 @@ def equal(input1, input2):
|
|
|
|
|
dtype = input1.dtype
|
|
|
|
|
|
|
|
|
|
# get equal compute
|
|
|
|
|
t_value = akg.tvm.compute(shape, lambda *indice: akg.tvm.const(1, dtype), "T")
|
|
|
|
|
f_value = akg.tvm.compute(shape, lambda *indice: akg.tvm.const(0, dtype), "F")
|
|
|
|
|
|
|
|
|
|
input1_bro = akg.topi.broadcast_to(input1, shape)
|
|
|
|
|
input2_bro = akg.topi.broadcast_to(input2, shape)
|
|
|
|
|
c_out = akg.tvm.compute(shape, lambda *indice: akg.tvm.expr.Select(input1_bro[indice] == input2_bro[indice],
|
|
|
|
|
t_value[indice], f_value[indice]), name="C")
|
|
|
|
|
res = akg.tvm.compute(shape, lambda *indice: c_out(*indice).astype("bool"), name="res")
|
|
|
|
|
t_value = _akg.tvm.compute(shape, lambda *indice: _akg.tvm.const(1, dtype), "T")
|
|
|
|
|
f_value = _akg.tvm.compute(shape, lambda *indice: _akg.tvm.const(0, dtype), "F")
|
|
|
|
|
|
|
|
|
|
input1_bro = _akg.topi.broadcast_to(input1, shape)
|
|
|
|
|
input2_bro = _akg.topi.broadcast_to(input2, shape)
|
|
|
|
|
c_out = _akg.tvm.compute(shape, lambda *indice: _akg.tvm.expr.Select(input1_bro[indice] == input2_bro[indice],
|
|
|
|
|
t_value[indice], f_value[indice]), name="C")
|
|
|
|
|
res = _akg.tvm.compute(shape, lambda *indice: c_out(*indice).astype("bool"), name="res")
|
|
|
|
|
|
|
|
|
|
return res
|