@@ -99,20 +99,20 @@ class GemmConvKernel : public framework::OpKernel<T> {
     // use col_shape in the im2col calculation
     // col_shape_vec: {i_c/g, k_h, k_w, o_h, o_w} or {i_c/g, k_d, k_h, k_w, o_d,
     // o_h, o_w}
-    std::vector<int64_t> col_shape_vec(filter_shape_vec.size() +
-                                       output_shape_vec.size() - 3);
-    col_shape_vec.assign(1, input->dims()[1] / groups);
-    col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin() + 2,
-                         filter_shape_vec.end());
-    col_shape_vec.insert(col_shape_vec.end(), output_shape_vec.begin() + 2,
-                         output_shape_vec.end());
+    size_t data_dim = filter_shape_vec.size() - 2;
+    std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
+    col_shape_vec[0] = input->dims()[1] / groups;
+    for (size_t j = 0; j < data_dim; ++j) {
+      col_shape_vec[j + 1] = filter_shape_vec[j + 2];
+      col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2];
+    }
     framework::DDim col_shape(framework::make_ddim(col_shape_vec));
 
     // use col_matrix_shape in the gemm calculation
     // size: (i_c/g * k_h * k_w, o_h * o_w) or (i_c/g * k_d * k_h * k_w, o_d *
     // o_h * o_w)
     framework::DDim col_matrix_shape =
-        framework::flatten_to_2d(col_shape, filter_shape_vec.size() - 2 + 1);
+        framework::flatten_to_2d(col_shape, data_dim + 1);
 
     bool is_expand = IsExpand(filter_shape_vec, strides, paddings, dilations);
     Tensor col;
@@ -155,13 +155,13 @@ class GemmConvKernel : public framework::OpKernel<T> {
           col.ShareDataWith(in_slice);
           col_matrix.ShareDataWith(col);
           col_matrix.Resize(col_matrix_shape);
-        } else if (filter_shape_vec.size() == 4) {
+        } else if (data_dim == 2U) {
           // im2col
           im2col(context.device_context(), in_slice, dilations, strides,
                  std::vector<int>{paddings[0], paddings[1], paddings[0],
                                   paddings[1]},
                  &col);
-        } else if (filter_shape_vec.size() == 5) {
+        } else if (data_dim == 3U) {
           // vol2col
           vol2col(context.device_context(), in_slice, dilations, strides,
                   paddings, &col);
@@ -211,13 +211,13 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
     // use col_shape in the im2col calculation
     // col_shape_vec: {i_c/g, k_h, k_w, o_h, o_w} or {i_c/g, k_d, k_h, k_w, o_d,
     // o_h, o_w}
-    std::vector<int64_t> col_shape_vec(filter_shape_vec.size() +
-                                       output_shape_vec.size() - 3);
-    col_shape_vec.assign(1, input->dims()[1] / groups);
-    col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin() + 2,
-                         filter_shape_vec.end());
-    col_shape_vec.insert(col_shape_vec.end(), output_shape_vec.begin() + 2,
-                         output_shape_vec.end());
+    size_t data_dim = filter_shape_vec.size() - 2;
+    std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
+    col_shape_vec[0] = input->dims()[1] / groups;
+    for (size_t j = 0; j < data_dim; ++j) {
+      col_shape_vec[j + 1] = filter_shape_vec[j + 2];
+      col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2];
+    }
     framework::DDim col_shape(framework::make_ddim(col_shape_vec));
 
     // use col_matrix_shape in the gemm calculation
@@ -225,7 +225,7 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
     // or
     // (i_c/g * k_d * k_h * k_w, o_d * o_h * o_w)
     framework::DDim col_matrix_shape =
-        framework::flatten_to_2d(col_shape, filter_shape_vec.size() - 2 + 1);
+        framework::flatten_to_2d(col_shape, data_dim + 1);
 
     framework::DDim input_shape = framework::slice_ddim(
         input->dims(), 1, static_cast<int>(input->dims().size()));
@@ -286,12 +286,12 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
                                  out_grad_slice, false, T(1.0), &col_matrix,
                                  T(0.0));
 
-          if (is_expand && filter_shape_vec.size() == 4) {
+          if (is_expand && data_dim == 2U) {
            col2im(context.device_context(), col, dilations, strides,
                   std::vector<int>{paddings[0], paddings[1], paddings[0],
                                    paddings[1]},
                   &in_grad_slice);
-          } else if (is_expand && filter_shape_vec.size() == 5) {
+          } else if (is_expand && data_dim == 3U) {
            col2vol(context.device_context(), col, dilations, strides, paddings,
                    &in_grad_slice);
          }
@@ -320,12 +320,12 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
           col.ShareDataWith(in_slice);
           col_matrix.ShareDataWith(col);
           col_matrix.Resize(col_matrix_shape);
-        } else if (filter_shape_vec.size() == 4) {
+        } else if (data_dim == 2U) {
          im2col(context.device_context(), in_slice, dilations, strides,
                 std::vector<int>{paddings[0], paddings[1], paddings[0],
                                  paddings[1]},
                 &col);
-        } else if (filter_shape_vec.size() == 5) {
+        } else if (data_dim == 3U) {
          vol2col(context.device_context(), in_slice, dilations, strides,
                  paddings, &col);
        }
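
For reference, below is a minimal standalone sketch (not part of the patch) checking that the new `data_dim` loop builds the same `col_shape_vec` as the removed `assign`/`insert` sequence, for both the 2-D (4-D filter) and 3-D (5-D filter) cases. The helper names `OldColShape`/`NewColShape` and the concrete shape values are made up for illustration only. Note that in the old code the vector's initial sizing was discarded anyway, because `assign` replaces the contents, which is part of what the loop-based rewrite cleans up.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Old construction: size the vector, then throw that away with assign()
// and append the spatial dims of filter and output via insert().
std::vector<int64_t> OldColShape(const std::vector<int64_t>& filter_shape_vec,
                                 const std::vector<int64_t>& output_shape_vec,
                                 int64_t input_channels, int groups) {
  std::vector<int64_t> col_shape_vec(filter_shape_vec.size() +
                                     output_shape_vec.size() - 3);
  col_shape_vec.assign(1, input_channels / groups);
  col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin() + 2,
                       filter_shape_vec.end());
  col_shape_vec.insert(col_shape_vec.end(), output_shape_vec.begin() + 2,
                       output_shape_vec.end());
  return col_shape_vec;
}

// New construction: one allocation of the final size, filled by index.
// data_dim is the number of spatial dims (2 for conv2d, 3 for conv3d).
std::vector<int64_t> NewColShape(const std::vector<int64_t>& filter_shape_vec,
                                 const std::vector<int64_t>& output_shape_vec,
                                 int64_t input_channels, int groups) {
  size_t data_dim = filter_shape_vec.size() - 2;
  std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
  col_shape_vec[0] = input_channels / groups;
  for (size_t j = 0; j < data_dim; ++j) {
    col_shape_vec[j + 1] = filter_shape_vec[j + 2];
    col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2];
  }
  return col_shape_vec;
}

int main() {
  // 2-D conv: filter {o_c, i_c/g, k_h, k_w}, output {n, o_c, o_h, o_w}.
  assert(OldColShape({64, 3, 3, 3}, {1, 64, 32, 32}, 3, 1) ==
         NewColShape({64, 3, 3, 3}, {1, 64, 32, 32}, 3, 1));
  // 3-D conv: filter {o_c, i_c/g, k_d, k_h, k_w},
  // output {n, o_c, o_d, o_h, o_w}.
  assert(OldColShape({64, 3, 3, 3, 3}, {1, 64, 16, 32, 32}, 3, 1) ==
         NewColShape({64, 3, 3, 3, 3}, {1, 64, 16, 32, 32}, 3, 1));
  return 0;
}
```

Both variants produce {i_c/g, k_h, k_w, o_h, o_w} (or the k_d/o_d form for 3-D), which is why the `flatten_to_2d` split point can likewise be expressed as `data_dim + 1` instead of `filter_shape_vec.size() - 2 + 1`.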