@@ -38,7 +38,7 @@ inline bool IsExpand(std::vector<int64_t>& filter_dim,
                      std::vector<int>& dilations) {
   bool filter_1 = true, strides_1 = true, padding_0 = true, dilation_1 = true;
   for (size_t j = 0; j < strides.size(); ++j) {
-    filter_1 = filter_1 && (static_cast<int>(filter_dim[j]) == 1);
+    filter_1 = filter_1 && (static_cast<int>(filter_dim[j + 2]) == 1);
     strides_1 = strides_1 && (strides[j] == 1);
     padding_0 = padding_0 && (paddings[j] == 0);
     dilation_1 = dilation_1 && (dilations[j] == 1);
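Context for the one-line change above: filter_dim now arrives with its leading {k_o, k_i} dims intact (see the next hunk), so the 1x1 test must skip to the spatial extents at j + 2. A minimal standalone sketch of the predicate, not part of the patch (hypothetical name IsExpandSketch, plain std::vector instead of the framework types):

#include <cstddef>
#include <cstdint>
#include <vector>

// im2col/vol2col can be skipped only for a 1x1 (or 1x1x1) filter with
// unit strides, zero paddings, and unit dilations; filter_dim keeps
// {k_o, k_i, ...} in front, hence the j + 2 offset into spatial dims.
inline bool IsExpandSketch(const std::vector<int64_t>& filter_dim,
                           const std::vector<int>& strides,
                           const std::vector<int>& paddings,
                           const std::vector<int>& dilations) {
  bool filter_1 = true, strides_1 = true, padding_0 = true, dilation_1 = true;
  for (size_t j = 0; j < strides.size(); ++j) {
    filter_1 = filter_1 && (static_cast<int>(filter_dim[j + 2]) == 1);
    strides_1 = strides_1 && (strides[j] == 1);
    padding_0 = padding_0 && (paddings[j] == 0);
    dilation_1 = dilation_1 && (dilations[j] == 1);
  }
  // e.g. a 1x1 conv, filter {64, 3, 1, 1}, stride 1, pad 0, dilation 1:
  // every flag stays true, the function returns false, and the kernel
  // can feed the input slice straight into the GEMM.
  return !(filter_1 && strides_1 && padding_0 && dilation_1);
}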
@@ -91,32 +91,28 @@ class GemmConvKernel : public framework::OpKernel<T> {
 
     const int batch_size = static_cast<int>(input->dims()[0]);
 
-    // filter_shape_vec: {k_h, k_w} or {k_d, k_h, k_w}
+    // filter_shape_vec: {k_o, k_i, k_h, k_w} or {k_o, k_i, k_d, k_h, k_w}
     std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
-    filter_shape_vec.erase(filter_shape_vec.begin(),
-                           filter_shape_vec.begin() + 2);
-
-    // output_shape_vec: {o_h, o_w} or {o_d, o_h, o_w}
+    // output_shape_vec: {o_n, o_c, o_h, o_w} or {o_n, o_c, o_d, o_h, o_w}
     std::vector<int64_t> output_shape_vec(framework::vectorize(output->dims()));
-    output_shape_vec.erase(output_shape_vec.begin(),
-                           output_shape_vec.begin() + 2);
 
     // use col_shape in the im2col calculation
     // col_shape_vec: {i_c/g, k_h, k_w, o_h, o_w} or {i_c/g, k_d, k_h, k_w, o_d,
     // o_h, o_w}
-    std::vector<int64_t> col_shape_vec;
-    col_shape_vec.push_back(input->dims()[1] / groups);
-    col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin(),
-                         filter_shape_vec.end());
-    col_shape_vec.insert(col_shape_vec.end(), output_shape_vec.begin(),
-                         output_shape_vec.end());
+    size_t data_dim = filter_shape_vec.size() - 2;
+    std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
+    col_shape_vec[0] = input->dims()[1] / groups;
+    for (size_t j = 0; j < data_dim; ++j) {
+      col_shape_vec[j + 1] = filter_shape_vec[j + 2];
+      col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2];
+    }
     framework::DDim col_shape(framework::make_ddim(col_shape_vec));
 
     // use col_matrix_shape in the gemm calculation
     // size: (i_c/g * k_h * k_w, o_h * o_w) or (i_c/g * k_d * k_h * k_w, o_d *
     // o_h * o_w)
     framework::DDim col_matrix_shape =
-        framework::flatten_to_2d(col_shape, filter_shape_vec.size() + 1);
+        framework::flatten_to_2d(col_shape, data_dim + 1);
 
     bool is_expand = IsExpand(filter_shape_vec, strides, paddings, dilations);
     Tensor col;
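The refactor above is pure bookkeeping: instead of erase-ing the leading dims and splicing vectors together, the kernel keeps the filter and output shapes whole and fills col_shape_vec by index. A standalone sketch, not part of the patch, with made-up sizes (input i_c = 6, groups = 2, 3x3 filter, 8x8 output) shows what lands in the buffer shape:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> filter_shape_vec{4, 3, 3, 3};  // {k_o, k_i, k_h, k_w}
  std::vector<int64_t> output_shape_vec{1, 4, 8, 8};  // {o_n, o_c, o_h, o_w}
  int64_t input_channels = 6;
  int groups = 2;

  // data_dim is the number of spatial dims: 2 for conv2d, 3 for conv3d.
  size_t data_dim = filter_shape_vec.size() - 2;
  std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
  col_shape_vec[0] = input_channels / groups;
  for (size_t j = 0; j < data_dim; ++j) {
    col_shape_vec[j + 1] = filter_shape_vec[j + 2];             // k_h, k_w
    col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2];  // o_h, o_w
  }

  // Prints: 3 3 3 8 8  ->  {i_c/g, k_h, k_w, o_h, o_w}
  for (int64_t d : col_shape_vec) std::cout << d << ' ';
  std::cout << '\n';
}

flatten_to_2d(col_shape, data_dim + 1) then folds this into the (i_c/g * k_h * k_w, o_h * o_w) = (27, 64) matrix the GEMM consumes.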
@@ -159,13 +155,13 @@ class GemmConvKernel : public framework::OpKernel<T> {
           col.ShareDataWith(in_slice);
           col_matrix.ShareDataWith(col);
           col_matrix.Resize(col_matrix_shape);
-        } else if (filter_shape_vec.size() == 2) {
+        } else if (data_dim == 2U) {
           // im2col
           im2col(context.device_context(), in_slice, dilations, strides,
                  std::vector<int>{paddings[0], paddings[1], paddings[0],
                                   paddings[1]},
                  &col);
-        } else if (filter_shape_vec.size() == 3) {
+        } else if (data_dim == 3U) {
           // vol2col
           vol2col(context.device_context(), in_slice, dilations, strides,
                   paddings, &col);
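In the dispatch above only the test changed: data_dim replaces filter_shape_vec.size(), now that the vector keeps all four (or five) dims. The four-element vector handed to im2col is {p_h, p_w, p_h, p_w}, i.e. the same padding applied to both sides of each spatial dim. For readers new to the trick, a minimal sketch, not part of the patch (hypothetical Im2ColSketch; single channel, stride 1, no padding or dilation), shows what the col buffer holds:

#include <vector>

// Each output position gets the k_h x k_w patch that produces it, laid
// out as {k_h, k_w, o_h, o_w} — the spatial tail of col_shape_vec.
// For a 4x4 input and a 3x3 filter this yields a 9 x 4 patch matrix.
std::vector<float> Im2ColSketch(const std::vector<float>& im, int i_h, int i_w,
                                int k_h, int k_w) {
  int o_h = i_h - k_h + 1, o_w = i_w - k_w + 1;
  std::vector<float> col(k_h * k_w * o_h * o_w);
  for (int kh = 0; kh < k_h; ++kh)
    for (int kw = 0; kw < k_w; ++kw)
      for (int oh = 0; oh < o_h; ++oh)
        for (int ow = 0; ow < o_w; ++ow)
          col[((kh * k_w + kw) * o_h + oh) * o_w + ow] =
              im[(oh + kh) * i_w + (ow + kw)];
  return col;
}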
@@ -206,26 +202,22 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
 
     const int batch_size = static_cast<int>(input->dims()[0]);
 
-    // filter_shape_vec: {k_h, k_w} or {k_d, k_h, k_w}
+    // filter_shape_vec: {k_o, k_i, k_h, k_w} or {k_o, k_i, k_d, k_h, k_w}
     std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
-    filter_shape_vec.erase(filter_shape_vec.begin(),
-                           filter_shape_vec.begin() + 2);
-
-    // output_shape_vec: {o_h, o_w} or {o_d, o_h, o_w}
+    // output_shape_vec: {o_n, o_c, o_h, o_w} or {o_n, o_c, o_d, o_h, o_w}
     std::vector<int64_t> output_shape_vec(
         framework::vectorize(output_grad->dims()));
-    output_shape_vec.erase(output_shape_vec.begin(),
-                           output_shape_vec.begin() + 2);
 
     // use col_shape in the im2col calculation
     // col_shape_vec: {i_c/g, k_h, k_w, o_h, o_w} or {i_c/g, k_d, k_h, k_w, o_d,
     // o_h, o_w}
-    std::vector<int64_t> col_shape_vec;
-    col_shape_vec.push_back(input->dims()[1] / groups);
-    col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin(),
-                         filter_shape_vec.end());
-    col_shape_vec.insert(col_shape_vec.end(), output_shape_vec.begin(),
-                         output_shape_vec.end());
+    size_t data_dim = filter_shape_vec.size() - 2;
+    std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
+    col_shape_vec[0] = input->dims()[1] / groups;
+    for (size_t j = 0; j < data_dim; ++j) {
+      col_shape_vec[j + 1] = filter_shape_vec[j + 2];
+      col_shape_vec[j + 1 + data_dim] = output_shape_vec[j + 2];
+    }
     framework::DDim col_shape(framework::make_ddim(col_shape_vec));
 
     // use col_matrix_shape in the gemm calculation
@@ -233,7 +225,7 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
     // or
     // (i_c/g * k_d * k_h * k_w, o_d * o_h * o_w)
     framework::DDim col_matrix_shape =
-        framework::flatten_to_2d(col_shape, filter_shape_vec.size() + 1);
+        framework::flatten_to_2d(col_shape, data_dim + 1);
 
     framework::DDim input_shape = framework::slice_ddim(
         input->dims(), 1, static_cast<int>(input->dims().size()));
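Same one-line fix as in the forward kernel: the split axis handed to flatten_to_2d must be the number of spatial dims plus one, and filter_shape_vec.size() + 1 only equaled that while the leading {k_o, k_i} were erased. A sketch of the split-axis semantics, not the framework API (hypothetical FlattenTo2DSketch over plain vectors):

#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <utility>
#include <vector>

// Multiply the dims before `axis` into the row extent and the rest
// into the column extent — a 2-D view of an N-D shape.
std::pair<int64_t, int64_t> FlattenTo2DSketch(
    const std::vector<int64_t>& shape, size_t axis) {
  auto mul = [](auto first, auto last) {
    return std::accumulate(first, last, int64_t{1}, std::multiplies<int64_t>{});
  };
  return {mul(shape.begin(), shape.begin() + axis),
          mul(shape.begin() + axis, shape.end())};
}

int main() {
  // col_shape {i_c/g=3, k_h=3, k_w=3, o_h=8, o_w=8}, axis = data_dim + 1 = 3:
  auto dims = FlattenTo2DSketch({3, 3, 3, 8, 8}, 3);
  std::cout << dims.first << " x " << dims.second << '\n';  // 27 x 64
}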
@@ -294,12 +286,12 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
                                  out_grad_slice, false, T(1.0), &col_matrix,
                                  T(0.0));
 
-          if (is_expand && filter_shape_vec.size() == 2) {
+          if (is_expand && data_dim == 2U) {
            col2im(context.device_context(), col, dilations, strides,
                    std::vector<int>{paddings[0], paddings[1], paddings[0],
                                     paddings[1]},
                    &in_grad_slice);
-          } else if (is_expand && filter_shape_vec.size() == 3) {
+          } else if (is_expand && data_dim == 3U) {
             col2vol(context.device_context(), col, dilations, strides, paddings,
                     &in_grad_slice);
           }
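The gradient path mirrors the forward dispatch: col2im/col2vol scatter the column buffer back into the input-gradient slice, summing where receptive fields overlap. A minimal sketch, not part of the patch (hypothetical Col2ImSketch; single channel, stride 1, no padding or dilation), the exact inverse of the im2col sketch earlier:

#include <vector>

// Scatter-add each col entry back to the input position it was gathered
// from; overlapping windows accumulate, which is exactly what the input
// gradient needs.
void Col2ImSketch(const std::vector<float>& col, int i_h, int i_w, int k_h,
                  int k_w, std::vector<float>* im) {
  int o_h = i_h - k_h + 1, o_w = i_w - k_w + 1;
  im->assign(i_h * i_w, 0.0f);
  for (int kh = 0; kh < k_h; ++kh)
    for (int kw = 0; kw < k_w; ++kw)
      for (int oh = 0; oh < o_h; ++oh)
        for (int ow = 0; ow < o_w; ++ow)
          (*im)[(oh + kh) * i_w + (ow + kw)] +=
              col[((kh * k_w + kw) * o_h + oh) * o_w + ow];
}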
@@ -328,12 +320,12 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
             col.ShareDataWith(in_slice);
             col_matrix.ShareDataWith(col);
             col_matrix.Resize(col_matrix_shape);
-          } else if (filter_shape_vec.size() == 2) {
+          } else if (data_dim == 2U) {
             im2col(context.device_context(), in_slice, dilations, strides,
                    std::vector<int>{paddings[0], paddings[1], paddings[0],
                                     paddings[1]},
                    &col);
-          } else if (filter_shape_vec.size() == 3) {
+          } else if (data_dim == 3U) {
             vol2col(context.device_context(), in_slice, dilations, strides,
                     paddings, &col);
           }
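Why this single buffer layout is worth the bookkeeping: viewed as the (i_c/g * k_h * k_w, o_h * o_w) matrix, the col buffer serves the forward product, the input-gradient product (visible above as the matmul with filter_slice transposed), and the filter-gradient product. A small shape-bookkeeping sketch, not part of the patch, using the hypothetical sizes from the earlier sketches (o_c/g = 2, i_c/g * k_h * k_w = 27, o_h * o_w = 64):

#include <cstdio>

int main() {
  int m = 2, k = 27, n = 64;
  // forward:     filter (m x k)   * col (k x n)      -> out      (m x n)
  // input grad:  filter^T (k x m) * d_out (m x n)    -> col      (k x n)
  // filter grad: d_out (m x n)    * col^T (n x k)    -> d_filter (m x k)
  std::printf("forward:     (%d x %d) * (%d x %d) -> (%d x %d)\n", m, k, k, n, m, n);
  std::printf("input grad:  (%d x %d) * (%d x %d) -> (%d x %d)\n", k, m, m, n, k, n);
  std::printf("filter grad: (%d x %d) * (%d x %d) -> (%d x %d)\n", m, n, n, k, m, k);
}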