|
|
|
@ -421,15 +421,27 @@ REGISTER_OPERATOR(matmul, ops::MatMulOp, ops::MatMulOpMaker,
|
|
|
|
|
ops::MatMulOpGradMaker);
|
|
|
|
|
REGISTER_OPERATOR(matmul_grad, ops::MatMulOpGrad);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
matmul, ops::MatMulKernel<paddle::platform::CPUDeviceContext, float>);
|
|
|
|
|
matmul, ops::MatMulKernel<paddle::platform::CPUDeviceContext, float>,
|
|
|
|
|
ops::MatMulKernel<paddle::platform::CPUDeviceContext, double>,
|
|
|
|
|
ops::MatMulKernel<paddle::platform::CPUDeviceContext,
|
|
|
|
|
paddle::platform::float16>);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
matmul_grad,
|
|
|
|
|
ops::MatMulGradKernel<paddle::platform::CPUDeviceContext, float>);
|
|
|
|
|
ops::MatMulGradKernel<paddle::platform::CPUDeviceContext, float>,
|
|
|
|
|
ops::MatMulGradKernel<paddle::platform::CPUDeviceContext, double>,
|
|
|
|
|
ops::MatMulGradKernel<paddle::platform::CPUDeviceContext,
|
|
|
|
|
paddle::platform::float16>);
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
REGISTER_OP_CUDA_KERNEL(
|
|
|
|
|
matmul, ops::MatMulKernel<paddle::platform::CUDADeviceContext, float>);
|
|
|
|
|
matmul, ops::MatMulKernel<paddle::platform::CUDADeviceContext, float>,
|
|
|
|
|
ops::MatMulKernel<paddle::platform::CUDADeviceContext, double>,
|
|
|
|
|
ops::MatMulKernel<paddle::platform::CUDADeviceContext,
|
|
|
|
|
paddle::platform::float16>);
|
|
|
|
|
REGISTER_OP_CUDA_KERNEL(
|
|
|
|
|
matmul_grad,
|
|
|
|
|
ops::MatMulGradKernel<paddle::platform::CUDADeviceContext, float>);
|
|
|
|
|
ops::MatMulGradKernel<paddle::platform::CUDADeviceContext, float>,
|
|
|
|
|
ops::MatMulGradKernel<paddle::platform::CUDADeviceContext, double>,
|
|
|
|
|
ops::MatMulGradKernel<paddle::platform::CUDADeviceContext,
|
|
|
|
|
paddle::platform::float16>);
|
|
|
|
|
#endif
|
|
|
|
|