|
|
|
@ -63,6 +63,7 @@ class BilinearTensorProductKernel : public framework::OpKernel<T> {
|
|
|
|
|
math::gemm<Place, T>(ctx.device_context(), CblasNoTrans, CblasNoTrans,
|
|
|
|
|
batch_size, y_dim, x_dim, 1, x->data<T>(),
|
|
|
|
|
weight_mat.data<T>(), 0, left_mul.data<T>());
|
|
|
|
|
Eigen::array<int, 2> shape({{static_cast<int>(out->dims()[0]), 1}});
|
|
|
|
|
output_col_vec.device(place) =
|
|
|
|
|
(left_mul_mat * y_mat).sum(Eigen::DSizes<int, 1>(1));
|
|
|
|
|
}
|
|
|
|
@ -174,7 +175,7 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
// Caculate the gradient of Input(Bias).
|
|
|
|
|
if (d_bias) {
|
|
|
|
|
d_bias->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
auto d_bias_mat = EigenMatrix<T>::From(*d_bias);
|
|
|
|
|
auto d_bias_mat = framework::EigenVector<T>::Flatten(*d_bias);
|
|
|
|
|
d_bias_mat.device(place) = d_out_mat.sum(Eigen::DSizes<int, 1>(0));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|