@@ -506,7 +506,7 @@ class GemmConvDoubleGradKernel : public framework::OpKernel<T> {
     // dw = ddx * dy ==> dw(Cout, Cin, kh, kw), ddx(N, Cin, H, W), dy(N, Cout,
     // oH, oW)
     // dw convolution double grad: im2col(vol2col) + gemm
-    if (dW) {
+    if (dW && ddX) {
       dW->mutable_data<T>(ctx.GetPlace());
       set_zero(dev_ctx, dW, static_cast<T>(0));
       Tensor dW_arr = *dW;
@@ -549,36 +549,38 @@ class GemmConvDoubleGradKernel : public framework::OpKernel<T> {
       math::Im2ColFunctor<math::ColFormat::kCFO, DeviceContext, T> im2col;
       math::Vol2ColFunctor<DeviceContext, T> vol2col;
       for (int i = 0; i < batch_size; ++i) {
-        Tensor ddx_batch = ddX->Slice(i, i + 1).Resize(input_shape);
-        Tensor x_batch = X->Slice(i, i + 1).Resize(input_shape);
         Tensor ddy_batch = ddY->Slice(i, i + 1).Resize(output_matrix_shape);
         for (int g = 0; g < groups; ++g) {
-          Tensor x_slice = x_batch.Slice(g * in_step, (g + 1) * in_step);
-          Tensor ddx_slice = ddx_batch.Slice(g * in_step, (g + 1) * in_step);
-          if (!is_expand) {
-            col.ShareDataWith(ddx_slice);
-            col_matrix.ShareDataWith(col);
-            col_matrix.Resize(col_matrix_shape);
-          } else if (data_dim == 2U) {
-            // im2col
-            im2col(dev_ctx, ddx_slice, dilations, strides,
-                   std::vector<int>{paddings[0], paddings[1], paddings[0],
-                                    paddings[1]},
-                   &col);
-          } else if (data_dim == 3U) {
-            // vol2col
-            vol2col(dev_ctx, ddx_slice, dilations, strides, paddings, &col);
-          }
-
-          // gemm
           Tensor ddy_slice = ddy_batch.Slice(g * out_step, (g + 1) * out_step);
-          Tensor w_slice = W.Slice(g * out_step, (g + 1) * out_step);
-          blas.MatMul(w_slice, false, col_matrix, false, T(1.0), &ddy_slice,
-                      T(0.0));
+          if (ddX) {
+            Tensor ddx_batch = ddX->Slice(i, i + 1).Resize(input_shape);
+            Tensor ddx_slice = ddx_batch.Slice(g * in_step, (g + 1) * in_step);
+            if (!is_expand) {
+              col.ShareDataWith(ddx_slice);
+              col_matrix.ShareDataWith(col);
+              col_matrix.Resize(col_matrix_shape);
+            } else if (data_dim == 2U) {
+              // im2col
+              im2col(dev_ctx, ddx_slice, dilations, strides,
+                     std::vector<int>{paddings[0], paddings[1], paddings[0],
+                                      paddings[1]},
+                     &col);
+            } else if (data_dim == 3U) {
+              // vol2col
+              vol2col(dev_ctx, ddx_slice, dilations, strides, paddings, &col);
+            }
+
+            // gemm
+            Tensor w_slice = W.Slice(g * out_step, (g + 1) * out_step);
+            blas.MatMul(w_slice, false, col_matrix, false, T(1.0), &ddy_slice,
+                        T(0.0));
+          }
 
           if (ddW_in) {
             Tensor ddW;
             ddW.ShareDataWith(*ddW_in).Resize(filter_matrix_shape);
+            Tensor x_batch = X->Slice(i, i + 1).Resize(input_shape);
+            Tensor x_slice = x_batch.Slice(g * in_step, (g + 1) * in_step);
 
             if (!is_expand) {
               col.ShareDataWith(x_slice);
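Note: the change makes both gemm paths of the double-grad kernel null-safe. ddX is an optional input, so the ddX-dependent im2col/vol2col + gemm now runs only under `if (ddX)`, and the `x_batch`/`x_slice` tensors are materialized only inside `if (ddW_in)` where they are actually used. A minimal standalone sketch of the guard pattern (illustrative only, not code from this PR; `Tensor` and `DoubleGradSketch` are hypothetical stand-ins):

#include <iostream>
#include <vector>

// Hypothetical stand-in for framework::Tensor, for illustration only.
struct Tensor {
  std::vector<float> data;
};

// ddY = W * ddX + X * ddW: each term is accumulated only when its
// (optional) double-grad input was actually fed, mirroring the
// if (ddX) / if (ddW_in) guards in the diff above.
void DoubleGradSketch(const Tensor* ddX, const Tensor* ddW_in, Tensor* ddY) {
  if (ddX) {
    // im2col(ddX) + gemm with W would accumulate into ddY here.
    std::cout << "accumulate W * ddX term\n";
  }
  if (ddW_in) {
    // im2col(X) + gemm with ddW would accumulate into ddY here.
    std::cout << "accumulate X * ddW term\n";
  }
}

int main() {
  Tensor ddw, ddy;
  // Safe even though ddX is absent: that term is simply skipped,
  // whereas an unconditional ddX->Slice(...) would dereference null.
  DoubleGradSketch(/*ddX=*/nullptr, &ddw, &ddy);
  return 0;
}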