diff --git a/paddle/fluid/operators/expand_op.cc b/paddle/fluid/operators/expand_op.cc index 6aa4c76b9c..44a2f37b66 100644 --- a/paddle/fluid/operators/expand_op.cc +++ b/paddle/fluid/operators/expand_op.cc @@ -146,7 +146,11 @@ REGISTER_OPERATOR(expand, ops::ExpandOp, ops::ExpandOpMaker, paddle::framework::DefaultGradOpDescMaker); REGISTER_OPERATOR(expand_grad, ops::ExpandGradOp); REGISTER_OP_CPU_KERNEL( - expand, ops::ExpandKernel); + expand, ops::ExpandKernel, + ops::ExpandKernel, + ops::ExpandKernel, + ops::ExpandKernel); REGISTER_OP_CPU_KERNEL( expand_grad, - ops::ExpandGradKernel); + ops::ExpandGradKernel, + ops::ExpandGradKernel); diff --git a/paddle/fluid/operators/expand_op.cu b/paddle/fluid/operators/expand_op.cu index d95c9b6180..50a506b294 100644 --- a/paddle/fluid/operators/expand_op.cu +++ b/paddle/fluid/operators/expand_op.cu @@ -15,7 +15,11 @@ limitations under the License. */ namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( - expand, ops::ExpandKernel); + expand, ops::ExpandKernel, + ops::ExpandKernel, + ops::ExpandKernel, + ops::ExpandKernel); REGISTER_OP_CUDA_KERNEL( expand_grad, - ops::ExpandGradKernel); + ops::ExpandGradKernel, + ops::ExpandGradKernel); diff --git a/python/paddle/fluid/tests/unittests/test_expand_op.py b/python/paddle/fluid/tests/unittests/test_expand_op.py index 67a8d8f072..218fc697f2 100644 --- a/python/paddle/fluid/tests/unittests/test_expand_op.py +++ b/python/paddle/fluid/tests/unittests/test_expand_op.py @@ -109,5 +109,29 @@ class TestExpandOpRank4(OpTest): self.check_grad(['X'], 'Out') +class TestExpandOpInteger(OpTest): + def setUp(self): + self.op_type = "expand" + self.inputs = {'X': np.random.random((2, 4, 5)).astype("int32")} + self.attrs = {'expand_times': [2, 1, 4]} + output = np.tile(self.inputs['X'], (2, 1, 4)) + self.outputs = {'Out': output} + + def test_check_output(self): + self.check_output() + + +class TestExpandOpBoolean(OpTest): + def setUp(self): + self.op_type = "expand" + self.inputs = {'X': np.random.random((2, 4, 5)).astype("bool")} + self.attrs = {'expand_times': [2, 1, 4]} + output = np.tile(self.inputs['X'], (2, 1, 4)) + self.outputs = {'Out': output} + + def test_check_output(self): + self.check_output() + + if __name__ == "__main__": unittest.main()