|
|
|
@ -135,7 +135,8 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
// col_matrix = filter * input_batch
|
|
|
|
|
// of shape (c * k_h * k_w, h * w) or (c * k_d * k_h * k_w, d * h * w)
|
|
|
|
|
blas.MatMul(filter, true, input_batch, false, &col_matrix);
|
|
|
|
|
blas.MatMul(filter, true, input_batch, false, static_cast<T>(1.0),
|
|
|
|
|
&col_matrix, static_cast<T>(0.0));
|
|
|
|
|
|
|
|
|
|
if (data_dim == 2U) {
|
|
|
|
|
// col2im: col_matrix -> dy
|
|
|
|
@ -267,7 +268,8 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
// or
|
|
|
|
|
// (m, c * k_d * k_h * k_w) * (c * k_d * k_h * k_w, d * h * w) -> (m,
|
|
|
|
|
// d, h, w)
|
|
|
|
|
blas.MatMul(filter, false, col_matrix, false, &input_grad_batch);
|
|
|
|
|
blas.MatMul(filter, false, col_matrix, false, static_cast<T>(1.0),
|
|
|
|
|
&input_grad_batch, static_cast<T>(0.0));
|
|
|
|
|
}
|
|
|
|
|
if (filter_grad) {
|
|
|
|
|
// input batch
|
|
|
|
@ -277,7 +279,8 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
// or
|
|
|
|
|
// (m, d * h * w) * (d * h * w, c * k_d * k_h * k_w) -> (m, c * k_d *
|
|
|
|
|
// k_h * k_w)
|
|
|
|
|
blas.MatMul(in_batch, false, col_matrix, true, &filter_grad_);
|
|
|
|
|
blas.MatMul(in_batch, false, col_matrix, true, static_cast<T>(1.0),
|
|
|
|
|
&filter_grad_, static_cast<T>(1.0));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|