support multiple var types for expand op, test=develop

revert-15774-anakin_subgraph_engine
jerrywgz 6 years ago
parent 381f2015a5
commit 8fc0fc314a

@ -146,7 +146,11 @@ REGISTER_OPERATOR(expand, ops::ExpandOp, ops::ExpandOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(expand_grad, ops::ExpandGradOp);
REGISTER_OP_CPU_KERNEL(
expand, ops::ExpandKernel<paddle::platform::CPUDeviceContext, float>);
expand, ops::ExpandKernel<paddle::platform::CPUDeviceContext, float>,
ops::ExpandKernel<paddle::platform::CPUDeviceContext, double>,
ops::ExpandKernel<paddle::platform::CPUDeviceContext, int>,
ops::ExpandKernel<paddle::platform::CPUDeviceContext, bool>);
REGISTER_OP_CPU_KERNEL(
expand_grad,
ops::ExpandGradKernel<paddle::platform::CPUDeviceContext, float>);
ops::ExpandGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::ExpandGradKernel<paddle::platform::CPUDeviceContext, double>);

@ -15,7 +15,11 @@ limitations under the License. */
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
expand, ops::ExpandKernel<paddle::platform::CUDADeviceContext, float>);
expand, ops::ExpandKernel<paddle::platform::CUDADeviceContext, float>,
ops::ExpandKernel<paddle::platform::CUDADeviceContext, double>,
ops::ExpandKernel<paddle::platform::CUDADeviceContext, int>,
ops::ExpandKernel<paddle::platform::CUDADeviceContext, bool>);
REGISTER_OP_CUDA_KERNEL(
expand_grad,
ops::ExpandGradKernel<paddle::platform::CUDADeviceContext, float>);
ops::ExpandGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::ExpandGradKernel<paddle::platform::CUDADeviceContext, double>);

@ -109,5 +109,29 @@ class TestExpandOpRank4(OpTest):
self.check_grad(['X'], 'Out')
class TestExpandOpInteger(OpTest):
def setUp(self):
self.op_type = "expand"
self.inputs = {'X': np.random.random((2, 4, 5)).astype("int32")}
self.attrs = {'expand_times': [2, 1, 4]}
output = np.tile(self.inputs['X'], (2, 1, 4))
self.outputs = {'Out': output}
def test_check_output(self):
self.check_output()
class TestExpandOpBoolean(OpTest):
def setUp(self):
self.op_type = "expand"
self.inputs = {'X': np.random.random((2, 4, 5)).astype("bool")}
self.attrs = {'expand_times': [2, 1, 4]}
output = np.tile(self.inputs['X'], (2, 1, 4))
self.outputs = {'Out': output}
def test_check_output(self):
self.check_output()
if __name__ == "__main__":
unittest.main()

Loading…
Cancel
Save