|
|
|
@ -43,24 +43,26 @@ class BilinearTensorProductKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
auto batch_size = x->dims()[0];
|
|
|
|
|
auto weight_dims = weight->dims();
|
|
|
|
|
int Out_dim = weight_dims[0];
|
|
|
|
|
int X_dim = weight_dims[1];
|
|
|
|
|
int Y_dim = weight_dims[2];
|
|
|
|
|
auto place = ctx.GetEigenDevice<Place>();
|
|
|
|
|
|
|
|
|
|
// Create the intermediate variable to caculate the result of
|
|
|
|
|
// Input(X) multiplied by Input(Weight_i), the formula is:
|
|
|
|
|
// left_mul = X Weight_i.
|
|
|
|
|
Tensor left_mul;
|
|
|
|
|
left_mul.mutable_data<T>(framework::make_ddim({batch_size, weight_dims[2]}),
|
|
|
|
|
left_mul.mutable_data<T>(framework::make_ddim({batch_size, Y_dim}),
|
|
|
|
|
ctx.GetPlace());
|
|
|
|
|
auto left_mul_mat = EigenMatrix<T>::From(left_mul);
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < weight_dims[0]; ++i) {
|
|
|
|
|
for (int i = 0; i < Out_dim; ++i) {
|
|
|
|
|
auto output_col_vec = output_mat.chip(i, 1);
|
|
|
|
|
Tensor weight_mat = weight->Slice(i, i + 1).Resize(
|
|
|
|
|
framework::make_ddim({weight_dims[1], weight_dims[2]}));
|
|
|
|
|
Tensor weight_mat =
|
|
|
|
|
weight->Slice(i, i + 1).Resize(framework::make_ddim({X_dim, Y_dim}));
|
|
|
|
|
math::gemm<Place, T>(ctx.device_context(), CblasNoTrans, CblasNoTrans,
|
|
|
|
|
batch_size, weight_dims[2], weight_dims[1], 1,
|
|
|
|
|
x->data<T>(), weight_mat.data<T>(), 0,
|
|
|
|
|
left_mul.data<T>());
|
|
|
|
|
batch_size, Y_dim, X_dim, 1, x->data<T>(),
|
|
|
|
|
weight_mat.data<T>(), 0, left_mul.data<T>());
|
|
|
|
|
output_col_vec.device(place) =
|
|
|
|
|
(left_mul_mat * y_mat).sum(Eigen::DSizes<int, 1>(1));
|
|
|
|
|
}
|
|
|
|
@ -87,6 +89,9 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
auto batch_size = x->dims()[0];
|
|
|
|
|
auto weight_dims = weight->dims();
|
|
|
|
|
int Out_dim = weight_dims[0];
|
|
|
|
|
int X_dim = weight_dims[1];
|
|
|
|
|
int Y_dim = weight_dims[2];
|
|
|
|
|
|
|
|
|
|
auto x_mat = EigenMatrix<T>::From(*x);
|
|
|
|
|
auto y_mat = EigenMatrix<T>::From(*y);
|
|
|
|
@ -95,13 +100,13 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
// Create the intermediate variable to caculate the Output(Y@Grad).
|
|
|
|
|
Tensor x_scale;
|
|
|
|
|
x_scale.mutable_data<T>(framework::make_ddim({batch_size, weight_dims[1]}),
|
|
|
|
|
x_scale.mutable_data<T>(framework::make_ddim({batch_size, X_dim}),
|
|
|
|
|
ctx.GetPlace());
|
|
|
|
|
auto x_scale_mat = EigenMatrix<T>::From(x_scale);
|
|
|
|
|
|
|
|
|
|
// Create the intermediate variable to caculate the Output(X@Grad).
|
|
|
|
|
Tensor y_scale;
|
|
|
|
|
y_scale.mutable_data<T>(framework::make_ddim({batch_size, weight_dims[2]}),
|
|
|
|
|
y_scale.mutable_data<T>(framework::make_ddim({batch_size, Y_dim}),
|
|
|
|
|
ctx.GetPlace());
|
|
|
|
|
auto y_scale_mat = EigenMatrix<T>::From(y_scale);
|
|
|
|
|
|
|
|
|
@ -121,11 +126,11 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
// Caculate the Output(X@Grad) and Output(Y@Grad).
|
|
|
|
|
if (d_x || d_y) {
|
|
|
|
|
Eigen::DSizes<int, 2> bcast_for_x(1, weight_dims[2]);
|
|
|
|
|
Eigen::DSizes<int, 2> bcast_for_y(1, weight_dims[1]);
|
|
|
|
|
for (int i = 0; i < weight_dims[0]; ++i) {
|
|
|
|
|
Eigen::DSizes<int, 2> bcast_for_x(1, Y_dim);
|
|
|
|
|
Eigen::DSizes<int, 2> bcast_for_y(1, X_dim);
|
|
|
|
|
for (int i = 0; i < Out_dim; ++i) {
|
|
|
|
|
Tensor weight_i = weight->Slice(i, i + 1).Resize(
|
|
|
|
|
framework::make_ddim({weight_dims[1], weight_dims[2]}));
|
|
|
|
|
framework::make_ddim({X_dim, Y_dim}));
|
|
|
|
|
auto output_vec = d_out_mat.chip(i, 1);
|
|
|
|
|
if (d_x) {
|
|
|
|
|
y_scale_mat.device(place) =
|
|
|
|
@ -133,9 +138,8 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
.broadcast(bcast_for_x) *
|
|
|
|
|
y_mat;
|
|
|
|
|
math::gemm<Place, T>(ctx.device_context(), CblasNoTrans, CblasTrans,
|
|
|
|
|
batch_size, weight_dims[1], weight_dims[2], 1,
|
|
|
|
|
y_scale.data<T>(), weight_i.data<T>(), 1,
|
|
|
|
|
d_x->data<T>());
|
|
|
|
|
batch_size, X_dim, Y_dim, 1, y_scale.data<T>(),
|
|
|
|
|
weight_i.data<T>(), 1, d_x->data<T>());
|
|
|
|
|
}
|
|
|
|
|
if (d_y) {
|
|
|
|
|
x_scale_mat.device(place) =
|
|
|
|
@ -143,9 +147,8 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
.broadcast(bcast_for_y) *
|
|
|
|
|
x_mat;
|
|
|
|
|
math::gemm<Place, T>(ctx.device_context(), CblasNoTrans, CblasNoTrans,
|
|
|
|
|
batch_size, weight_dims[2], weight_dims[1], 1,
|
|
|
|
|
x_scale.data<T>(), weight_i.data<T>(), 1,
|
|
|
|
|
d_y->data<T>());
|
|
|
|
|
batch_size, Y_dim, X_dim, 1, x_scale.data<T>(),
|
|
|
|
|
weight_i.data<T>(), 1, d_y->data<T>());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -153,19 +156,18 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
// Caculate the gradient of Input(Weight).
|
|
|
|
|
if (d_weight) {
|
|
|
|
|
d_weight->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
Eigen::DSizes<int, 2> bcast_for_weight(1, weight_dims[1]);
|
|
|
|
|
for (int i = 0; i < weight_dims[0]; ++i) {
|
|
|
|
|
Eigen::DSizes<int, 2> bcast_for_weight(1, X_dim);
|
|
|
|
|
for (int i = 0; i < Out_dim; ++i) {
|
|
|
|
|
Tensor d_weight_i = d_weight->Slice(i, i + 1).Resize(
|
|
|
|
|
framework::make_ddim({weight_dims[1], weight_dims[2]}));
|
|
|
|
|
framework::make_ddim({X_dim, Y_dim}));
|
|
|
|
|
auto output_vec = d_out_mat.chip(i, 1);
|
|
|
|
|
x_scale_mat.device(place) =
|
|
|
|
|
output_vec.reshape(Eigen::DSizes<int, 2>(batch_size, 1))
|
|
|
|
|
.broadcast(bcast_for_weight) *
|
|
|
|
|
x_mat;
|
|
|
|
|
math::gemm<Place, T>(ctx.device_context(), CblasTrans, CblasNoTrans,
|
|
|
|
|
weight_dims[1], weight_dims[2], batch_size, 1,
|
|
|
|
|
x_scale.data<T>(), y->data<T>(), 0,
|
|
|
|
|
d_weight_i.data<T>());
|
|
|
|
|
X_dim, Y_dim, batch_size, 1, x_scale.data<T>(),
|
|
|
|
|
y->data<T>(), 0, d_weight_i.data<T>());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|