|
|
|
@ -70,7 +70,7 @@ class BilinearTensorProductKernel : public framework::OpKernel<T> {
|
|
|
|
|
if (bias) {
|
|
|
|
|
auto bias_vec = EigenMatrix<T>::From(*bias);
|
|
|
|
|
Eigen::DSizes<int, 2> bcast(batch_size, 1);
|
|
|
|
|
output_mat.device(place) = bias_vec.broadcast(bcast).eval() + output_mat;
|
|
|
|
|
output_mat.device(place) = bias_vec.broadcast(bcast) + output_mat;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
@ -143,8 +143,7 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
if (d_x) {
|
|
|
|
|
y_scale_mat.device(place) =
|
|
|
|
|
output_vec.reshape(Eigen::DSizes<int, 2>(batch_size, 1))
|
|
|
|
|
.broadcast(bcast_for_x)
|
|
|
|
|
.eval() *
|
|
|
|
|
.broadcast(bcast_for_x) *
|
|
|
|
|
y_mat;
|
|
|
|
|
blas.GEMM(CblasNoTrans, CblasTrans, batch_size, x_dim, y_dim, 1,
|
|
|
|
|
y_scale.data<T>(), weight_i.data<T>(), 1, d_x->data<T>());
|
|
|
|
@ -153,8 +152,7 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
if (d_y || d_weight) {
|
|
|
|
|
auto output_vec_y =
|
|
|
|
|
output_vec.reshape(Eigen::DSizes<int, 2>(batch_size, 1))
|
|
|
|
|
.broadcast(bcast_for_y)
|
|
|
|
|
.eval();
|
|
|
|
|
.broadcast(bcast_for_y);
|
|
|
|
|
x_scale_mat.device(place) = output_vec_y * x_mat;
|
|
|
|
|
if (d_y) {
|
|
|
|
|
blas.GEMM(CblasNoTrans, CblasNoTrans, batch_size, y_dim, x_dim, 1,
|
|
|
|
|