|
|
|
@ -46,6 +46,7 @@ class FSPOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
x_mat_desc.width_ = height * width;
|
|
|
|
|
x_mat_desc.batch_size_ = batch_size;
|
|
|
|
|
x_mat_desc.stride_ = x_channel * height * width;
|
|
|
|
|
x_mat_desc.trans_ = false;
|
|
|
|
|
|
|
|
|
|
math::MatDescriptor y_mat_desc;
|
|
|
|
|
y_mat_desc.height_ = height * width;
|
|
|
|
@ -93,12 +94,14 @@ class FSPGradOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
d_out_mat_desc.width_ = y_channel;
|
|
|
|
|
d_out_mat_desc.batch_size_ = batch_size;
|
|
|
|
|
d_out_mat_desc.stride_ = x_channel * y_channel;
|
|
|
|
|
d_out_mat_desc.trans_ = false;
|
|
|
|
|
|
|
|
|
|
math::MatDescriptor y_mat_desc;
|
|
|
|
|
y_mat_desc.height_ = y_channel;
|
|
|
|
|
y_mat_desc.width_ = h * w;
|
|
|
|
|
y_mat_desc.batch_size_ = batch_size;
|
|
|
|
|
y_mat_desc.stride_ = y_channel * h * w;
|
|
|
|
|
y_mat_desc.trans_ = false;
|
|
|
|
|
|
|
|
|
|
blas.MatMul(*d_out, d_out_mat_desc, *y, y_mat_desc,
|
|
|
|
|
static_cast<T>(1.0 / (h * w)), d_x, static_cast<T>(0.0));
|
|
|
|
@ -125,6 +128,7 @@ class FSPGradOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
x_mat_desc.width_ = h * w;
|
|
|
|
|
x_mat_desc.batch_size_ = batch_size;
|
|
|
|
|
x_mat_desc.stride_ = x_channel * h * w;
|
|
|
|
|
x_mat_desc.trans_ = false;
|
|
|
|
|
|
|
|
|
|
blas.MatMul(*d_out, d_out_mat_desc, *x, x_mat_desc,
|
|
|
|
|
static_cast<T>(1.0 / (h * w)), d_y, static_cast<T>(0.0));
|
|
|
|
|