|
|
|
|
@ -15,10 +15,12 @@ limitations under the License. */
|
|
|
|
|
#include "paddle/fluid/operators/unsqueeze_op.h"
|
|
|
|
|
|
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
|
namespace plat = paddle::platform;
|
|
|
|
|
|
|
|
|
|
// CUDA kernel list for the `unsqueeze` op: ops::UnsqueezeKernel instantiated
// on CUDADeviceContext once per element type (float, double, float16, int,
// int8_t, int64_t) and handed to REGISTER_OP_CUDA_KERNEL.
// `plat` is this file's alias for paddle::platform.
REGISTER_OP_CUDA_KERNEL(
    unsqueeze,
    ops::UnsqueezeKernel<plat::CUDADeviceContext, float>,
    ops::UnsqueezeKernel<plat::CUDADeviceContext, double>,
    ops::UnsqueezeKernel<plat::CUDADeviceContext, plat::float16>,
    ops::UnsqueezeKernel<plat::CUDADeviceContext, int>,
    ops::UnsqueezeKernel<plat::CUDADeviceContext, int8_t>,
    ops::UnsqueezeKernel<plat::CUDADeviceContext, int64_t>);
|
|
|
|
|
@ -26,6 +28,8 @@ REGISTER_OP_CUDA_KERNEL(
|
|
|
|
|
unsqueeze_grad,
|
|
|
|
|
ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, float>,
|
|
|
|
|
ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, double>,
|
|
|
|
|
ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext,
|
|
|
|
|
plat::float16>,
|
|
|
|
|
ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, int>,
|
|
|
|
|
ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, int8_t>,
|
|
|
|
|
ops::UnsqueezeGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
|
|
|
|
|
@ -33,6 +37,7 @@ REGISTER_OP_CUDA_KERNEL(
|
|
|
|
|
unsqueeze2,
|
|
|
|
|
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, float>,
|
|
|
|
|
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, double>,
|
|
|
|
|
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, plat::float16>,
|
|
|
|
|
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int>,
|
|
|
|
|
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int8_t>,
|
|
|
|
|
ops::UnsqueezeKernel<paddle::platform::CUDADeviceContext, int64_t>);
|
|
|
|
|
@ -40,6 +45,8 @@ REGISTER_OP_CUDA_KERNEL(
|
|
|
|
|
unsqueeze2_grad,
|
|
|
|
|
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, float>,
|
|
|
|
|
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, double>,
|
|
|
|
|
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext,
|
|
|
|
|
plat::float16>,
|
|
|
|
|
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, int>,
|
|
|
|
|
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, int8_t>,
|
|
|
|
|
ops::Unsqueeze2GradKernel<paddle::platform::CUDADeviceContext, int64_t>);
|
|
|
|
|
|