|
|
|
|
@ -49,10 +49,8 @@ class TestPythonOperatorOverride(unittest.TestCase):
|
|
|
|
|
# compare func to check
|
|
|
|
|
compare_fns = [
|
|
|
|
|
lambda _a, _b: _a == _b,
|
|
|
|
|
lambda _a, _b: _a == _b,
|
|
|
|
|
lambda _a, _b: _a < _b,
|
|
|
|
|
lambda _a, _b: _a < _b,
|
|
|
|
|
lambda _a, _b: _a <= _b,
|
|
|
|
|
lambda _a, _b: _a > _b,
|
|
|
|
|
lambda _a, _b: _a <= _b,
|
|
|
|
|
lambda _a, _b: _a >= _b,
|
|
|
|
|
]
|
|
|
|
|
@ -69,7 +67,7 @@ class TestPythonOperatorOverride(unittest.TestCase):
|
|
|
|
|
for dtype in dtypes:
|
|
|
|
|
for compare_fn in compare_fns:
|
|
|
|
|
with framework.program_guard(framework.Program(),
|
|
|
|
|
gframework.Program()):
|
|
|
|
|
framework.Program()):
|
|
|
|
|
self.check_result(compare_fn, place, dtype)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|