|
|
|
@ -26,8 +26,12 @@ template <typename T>
|
|
|
|
|
class CUDNNAffineGridOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
|
|
|
|
|
"It must use CUDAPlace.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
platform::is_gpu_place(ctx.GetPlace()), true,
|
|
|
|
|
platform::errors::InvalidArgument("Only "
|
|
|
|
|
"support for CUDAPlace.Please switch "
|
|
|
|
|
"your context from CPUPlace to "
|
|
|
|
|
"CUDAPlace or update your cudnn."));
|
|
|
|
|
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
|
|
|
|
|
auto handle = dev_ctx.cudnn_handle();
|
|
|
|
|
auto* theta = ctx.Input<Tensor>("Theta");
|
|
|
|
@ -56,8 +60,11 @@ class CUDNNAffineGridOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
cudnnSpatialTransformerDescriptor_t cudnn_st_desc =
|
|
|
|
|
st_desc.descriptor<T>(4, h_size_data);
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(platform::dynload::cudnnSpatialTfGridGeneratorForward(
|
|
|
|
|
handle, cudnn_st_desc, theta_data, output_data));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
platform::dynload::cudnnSpatialTfGridGeneratorForward(
|
|
|
|
|
handle, cudnn_st_desc, theta_data, output_data),
|
|
|
|
|
0, platform::errors::Fatal("Some errors has occurred "
|
|
|
|
|
"during forward computation in cudnn."));
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -65,8 +72,12 @@ template <typename T>
|
|
|
|
|
class CUDNNAffineGridGradOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
|
|
|
|
|
"It must use CUDAPlace.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Only "
|
|
|
|
|
"support for CUDAPlace. Please switch "
|
|
|
|
|
"your context from CPUPlace to "
|
|
|
|
|
"CUDAPlace or update your cudnn."));
|
|
|
|
|
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
|
|
|
|
|
auto handle = dev_ctx.cudnn_handle();
|
|
|
|
|
auto output_grad = ctx.Input<Tensor>(framework::GradVarName("Output"));
|
|
|
|
@ -95,8 +106,12 @@ class CUDNNAffineGridGradOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
const T* output_grad_data = output_grad->data<T>();
|
|
|
|
|
T* theta_grad_data = theta_grad->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(platform::dynload::cudnnSpatialTfGridGeneratorBackward(
|
|
|
|
|
handle, cudnn_st_desc, output_grad_data, theta_grad_data));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
platform::dynload::cudnnSpatialTfGridGeneratorBackward(
|
|
|
|
|
handle, cudnn_st_desc, output_grad_data, theta_grad_data),
|
|
|
|
|
0,
|
|
|
|
|
"Some errors "
|
|
|
|
|
"has occurred during forward computation in cudnn;");
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|