|
|
|
@ -19,11 +19,13 @@ import numpy as np
|
|
|
|
|
from op_test import OpTest
|
|
|
|
|
import paddle.fluid.core as core
|
|
|
|
|
from paddle.fluid.op import Operator
|
|
|
|
|
import paddle.fluid as fluid
|
|
|
|
|
from paddle.fluid import Program, program_guard
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestWhereOp(OpTest):
|
|
|
|
|
class TestWhereIndexOp(OpTest):
|
|
|
|
|
def setUp(self):
|
|
|
|
|
self.op_type = "where"
|
|
|
|
|
self.op_type = "where_index"
|
|
|
|
|
self.init_config()
|
|
|
|
|
|
|
|
|
|
def test_check_output(self):
|
|
|
|
@ -37,7 +39,7 @@ class TestWhereOp(OpTest):
|
|
|
|
|
|
|
|
|
|
class TestAllFalse(unittest.TestCase):
|
|
|
|
|
def setUp(self):
|
|
|
|
|
self.op_type = "where"
|
|
|
|
|
self.op_type = "where_index"
|
|
|
|
|
self.init_config()
|
|
|
|
|
|
|
|
|
|
def check_with_place(self, place):
|
|
|
|
@ -48,7 +50,7 @@ class TestAllFalse(unittest.TestCase):
|
|
|
|
|
out = scope.var("Out").get_tensor()
|
|
|
|
|
out.set(np.full(self.shape, 0).astype('int64'), place)
|
|
|
|
|
|
|
|
|
|
op = Operator("where", Condition="Condition", Out="Out")
|
|
|
|
|
op = Operator("where_index", Condition="Condition", Out="Out")
|
|
|
|
|
op.run(scope, place)
|
|
|
|
|
|
|
|
|
|
out_array = np.array(out)
|
|
|
|
@ -66,14 +68,14 @@ class TestAllFalse(unittest.TestCase):
|
|
|
|
|
self.check_with_place(core.CUDAPlace(0))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestRank2(TestWhereOp):
|
|
|
|
|
class TestRank2(TestWhereIndexOp):
|
|
|
|
|
def init_config(self):
|
|
|
|
|
self.inputs = {'Condition': np.array([[True, False], [False, True]]), }
|
|
|
|
|
|
|
|
|
|
self.outputs = {'Out': np.array([[0, 0], [1, 1]], dtype='int64')}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestRank3(TestWhereOp):
|
|
|
|
|
class TestRank3(TestWhereIndexOp):
|
|
|
|
|
def init_config(self):
|
|
|
|
|
self.inputs = {
|
|
|
|
|
'Condition': np.array([[[True, False], [False, True]],
|
|
|
|
@ -88,5 +90,17 @@ class TestRank3(TestWhereOp):
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestWhereOpError(unittest.TestCase):
|
|
|
|
|
def test_api(self):
|
|
|
|
|
with program_guard(Program(), Program()):
|
|
|
|
|
cond = fluid.layers.data(name='cond', shape=[4], dtype='bool')
|
|
|
|
|
result = fluid.layers.where(cond)
|
|
|
|
|
|
|
|
|
|
exe = fluid.Executor(fluid.CPUPlace())
|
|
|
|
|
exe.run(fluid.default_startup_program())
|
|
|
|
|
cond_i = np.array([True, False, False, False]).astype("bool")
|
|
|
|
|
out = exe.run(fluid.default_main_program(), feed={'cond': cond_i})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
unittest.main()
|