@@ -609,6 +609,8 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
     auto dX = ctx.Output<Tensor>("DInput");
     if (ddO) {
       ddO->mutable_data<T>(ctx.GetPlace());
+      math::SetConstant<platform::CUDADeviceContext, T> set_zero;
+      set_zero(dev_ctx, ddO, static_cast<T>(0));
     }
     if (dW) {
       dW->mutable_data<T>(ctx.GetPlace());
@@ -646,7 +648,7 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
     // transform Tensors to channel first-----------
     Tensor transformed_X_channel(X->type());
     Tensor transformed_dO_channel(dO->type());
-    Tensor transformed_ddX_channel(ddX->type());
+    Tensor transformed_ddX_channel(X->type());
 
     Tensor transformed_ddO_channel(dO->type());
     Tensor transformed_dX_channel(X->type());
@@ -662,10 +664,12 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
       TransToChannelFirst<platform::CUDADeviceContext, T>(
           ctx, dO, &transformed_dO_channel);
 
-      ResizeToChannelFirst<platform::CUDADeviceContext, T>(
-          ctx, ddX, &transformed_ddX_channel);
-      TransToChannelFirst<platform::CUDADeviceContext, T>(
-          ctx, ddX, &transformed_ddX_channel);
+      if (ddX) {
+        ResizeToChannelFirst<platform::CUDADeviceContext, T>(
+            ctx, ddX, &transformed_ddX_channel);
+        TransToChannelFirst<platform::CUDADeviceContext, T>(
+            ctx, ddX, &transformed_ddX_channel);
+      }
 
       if (ddO) {
         ResizeToChannelFirst<platform::CUDADeviceContext, T>(
@@ -680,7 +684,9 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
     } else {
       transformed_X_channel = *X;
       transformed_dO_channel = *dO;
-      transformed_ddX_channel = *ddX;
+      if (ddX) {
+        transformed_ddX_channel = *ddX;
+      }
       if (ddO) {
         transformed_ddO_channel.ShareDataWith(*ddO);
       }
@@ -729,15 +735,15 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
       transformed_X.Resize(new_input_shape);
       transformed_ddX.Resize(new_input_shape);
       transformed_dX.Resize(new_input_shape);
       auto& dev_ctx =
           ctx.template device_context<paddle::platform::CUDADeviceContext>();
 
       transformed_X =
           ctx.AllocateTmpTensor<T, paddle::platform::CUDADeviceContext>(
               new_input_shape, dev_ctx);
-      transformed_ddX =
-          ctx.AllocateTmpTensor<T, paddle::platform::CUDADeviceContext>(
-              new_input_shape, dev_ctx);
+      if (ddX) {
+        transformed_ddX =
+            ctx.AllocateTmpTensor<T, paddle::platform::CUDADeviceContext>(
+                new_input_shape, dev_ctx);
+      }
       if (dX) {
         transformed_dX =
             ctx.AllocateTmpTensor<T, paddle::platform::CUDADeviceContext>(
@@ -751,16 +757,20 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
         case 4: {
           math::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
               ctx, input_pad, transformed_X_channel, pad_value, &transformed_X);
-          math::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
-              ctx, input_pad, transformed_ddX_channel, pad_value,
-              &transformed_ddX);
+          if (ddX) {
+            math::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
+                ctx, input_pad, transformed_ddX_channel, pad_value,
+                &transformed_ddX);
+          }
         } break;
         case 5: {
           math::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
               ctx, input_pad, transformed_X_channel, pad_value, &transformed_X);
-          math::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
-              ctx, input_pad, transformed_ddX_channel, pad_value,
-              &transformed_ddX);
+          if (ddX) {
+            math::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
+                ctx, input_pad, transformed_ddX_channel, pad_value,
+                &transformed_ddX);
+          }
         } break;
         default:
           PADDLE_THROW("ConvOp only support tensors with 4 or 5 dimensions.");
@@ -768,7 +778,9 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
 
     } else {
       transformed_X.ShareDataWith(transformed_X_channel);
-      transformed_ddX.ShareDataWith(transformed_ddX_channel);
+      if (ddX) {
+        transformed_ddX.ShareDataWith(transformed_ddX_channel);
+      }
       if (dX) {
         transformed_dX.ShareDataWith(transformed_dX_channel);
       }
@@ -936,10 +948,9 @@ class CUDNNConvDoubleGradOpKernel : public framework::OpKernel<T> {
             ctx, &transformed_ddO_channel, ddO);
       }
     }
-    T* transformed_dy_channel = nullptr;
+    T* transformed_dy_channel = transformed_dO_channel.data<T>();
     if (dW && ddX) {
       ddx = transformed_ddX.data<T>();
-      transformed_dy_channel = transformed_dO_channel.data<T>();
      for (int i = 0; i < groups; i++) {
         wkspace_handle.RunFunc(
             [&](void* workspace_ptr) {
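
Note on the pattern above: every hunk applies the same fix. The double-grad inputs and outputs ddX, ddO, and dX are optional, so the old code's unconditional uses of ddX (for example `ddX->type()`) could dereference a null pointer. The new code constructs `transformed_ddX_channel` from `X->type()` instead, wraps each transform, allocation, pad, and share of ddX in `if (ddX)`, zero-fills ddO so later accumulation starts from zero, and initializes `transformed_dy_channel` directly from `transformed_dO_channel` so it stays valid even when the `if (dW && ddX)` branch is skipped. The standalone C++ sketch below illustrates that null-guard pattern; `Tensor`, `ZeroInit`, and `TransformIfPresent` are illustrative stand-ins, not Paddle's API.

// Minimal sketch of the null-guard pattern applied in the diff above.
// "Tensor", "ZeroInit", and "TransformIfPresent" are stand-ins, not Paddle API.
#include <cstdio>
#include <vector>

struct Tensor {
  std::vector<float> data;
};

// Zero-fill an optional output before anything accumulates into it,
// mirroring the SetConstant call added to the ddO branch.
void ZeroInit(Tensor* t, size_t n) {
  if (!t) return;  // absent output: nothing to do
  t->data.assign(n, 0.0f);
}

// Only touch an optional input when the caller actually supplied it,
// mirroring the `if (ddX) { Resize...; Trans...; }` guards.
void TransformIfPresent(const Tensor* src, Tensor* dst) {
  if (!src) return;  // no double-grad input: skip the transform
  *dst = *src;       // placeholder for the real layout transform
}

int main() {
  Tensor x{{1.0f, 2.0f, 3.0f}};
  Tensor transformed_ddx;

  TransformIfPresent(&x, &transformed_ddx);       // ddX present: transformed
  TransformIfPresent(nullptr, &transformed_ddx);  // ddX absent: no dereference

  Tensor ddo;
  ZeroInit(&ddo, 3);     // ddO zeroed before kernels add into it
  ZeroInit(nullptr, 3);  // absent ddO is skipped safely

  std::printf("transformed size=%zu, ddo[0]=%g\n",
              transformed_ddx.data.size(), static_cast<double>(ddo.data[0]));
  return 0;
}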