|
|
|
@ -63,7 +63,6 @@ class BilinearTensorProductKernel : public framework::OpKernel<T> {
|
|
|
|
|
math::gemm<Place, T>(ctx.device_context(), CblasNoTrans, CblasNoTrans,
|
|
|
|
|
batch_size, y_dim, x_dim, 1, x->data<T>(),
|
|
|
|
|
weight_mat.data<T>(), 0, left_mul.data<T>());
|
|
|
|
|
Eigen::array<int, 2> shape({{static_cast<int>(out->dims()[0]), 1}});
|
|
|
|
|
output_col_vec.device(place) =
|
|
|
|
|
(left_mul_mat * y_mat).sum(Eigen::DSizes<int, 1>(1));
|
|
|
|
|
}
|
|
|
|
|