@@ -70,7 +70,7 @@ class BilinearTensorProductKernel : public framework::OpKernel<T> {
     if (bias) {
       auto bias_vec = EigenMatrix<T>::From(*bias);
       Eigen::DSizes<int, 2> bcast(batch_size, 1);
-      output_mat.device(place) = bias_vec.broadcast(bcast) + output_mat;
+      output_mat.device(place) = bias_vec.broadcast(bcast).eval() + output_mat;
     }
   }
 };
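For context (an editor's note, not part of the patch): the only functional change in this hunk is the added `.eval()`, which materializes the broadcast bias into a concrete temporary before it enters the sum, presumably to avoid the pitfall that a lazily evaluated Eigen broadcast nested inside a larger expression can be re-evaluated or miscomputed on some device paths. A minimal, self-contained sketch of the same pattern, assuming Eigen's unsupported Tensor module; all names below are illustrative:

    // Minimal sketch of the .eval()-before-use pattern (illustrative only).
    #include <unsupported/Eigen/CXX11/Tensor>
    #include <iostream>

    int main() {
      const int batch_size = 3;
      const int out_dim = 2;

      Eigen::Tensor<float, 2> output(batch_size, out_dim);  // stand-in for output_mat
      output.setConstant(1.0f);

      Eigen::Tensor<float, 2> bias(1, out_dim);             // stand-in for bias_vec
      bias.setValues({{0.5f, -0.5f}});

      // Broadcast the 1 x out_dim bias across the batch dimension; .eval()
      // forces the broadcast into a temporary before the addition, which is
      // the exact shape of the change in BilinearTensorProductKernel.
      Eigen::DSizes<int, 2> bcast(batch_size, 1);
      output = bias.broadcast(bcast).eval() + output;

      std::cout << output << "\n";  // every row prints 1.5, 0.5
      return 0;
    }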
@@ -99,13 +99,13 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
     auto d_out_mat = EigenMatrix<T>::From(*d_out);
     auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
-    // Create the intermediate variable to caculate the Output(Y@Grad).
+    // Create the intermediate variable to calculate the Output(Y@Grad).
     Tensor x_scale;
     x_scale.mutable_data<T>(framework::make_ddim({batch_size, x_dim}),
                             ctx.GetPlace());
     auto x_scale_mat = EigenMatrix<T>::From(x_scale);

-    // Create the intermediate variable to caculate the Output(X@Grad).
+    // Create the intermediate variable to calculate the Output(X@Grad).
     Tensor y_scale;
     y_scale.mutable_data<T>(framework::make_ddim({batch_size, y_dim}),
                             ctx.GetPlace());
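For orientation before the gradient hunk (an editor's sketch under the usual reading of this op, not code from the patch): for each batch row b and output index i, the forward pass computes out[b][i] = sum over j,k of x[b][j] * W[i][j][k] * y[b][k] + bias[i]. A plain-loop reference, with every name hypothetical:

    #include <cstddef>
    #include <vector>

    // Scalar reference for the bilinear forward pass (hypothetical helper):
    // out[b][i] = sum_{j,k} x[b][j] * W[i][j][k] * y[b][k] + bias[i]
    std::vector<std::vector<float>> BilinearForwardRef(
        const std::vector<std::vector<float>>& x,               // batch_size x x_dim
        const std::vector<std::vector<std::vector<float>>>& W,  // out_dim x x_dim x y_dim
        const std::vector<std::vector<float>>& y,               // batch_size x y_dim
        const std::vector<float>& bias) {                       // out_dim
      const std::size_t batch_size = x.size();
      const std::size_t out_dim = W.size();
      std::vector<std::vector<float>> out(batch_size,
                                          std::vector<float>(out_dim, 0.0f));
      for (std::size_t b = 0; b < batch_size; ++b) {
        for (std::size_t i = 0; i < out_dim; ++i) {
          float acc = bias[i];
          for (std::size_t j = 0; j < x[b].size(); ++j) {
            for (std::size_t k = 0; k < y[b].size(); ++k) {
              acc += x[b][j] * W[i][j][k] * y[b][k];
            }
          }
          out[b][i] = acc;
        }
      }
      return out;
    }

Under that reading, for the current output index i the scratch tensors hold x_scale[b][j] = d_out[b][i] * x[b][j] and y_scale[b][k] = d_out[b][i] * y[b][k], which is exactly what the broadcast-multiply expressions in the next hunk compute.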
@@ -113,65 +113,64 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {

     math::SetConstant<DeviceContext, T> set_zero;

     // Set Output(X@Grad) be zero.
     if (d_x) {
       d_x->mutable_data<T>(ctx.GetPlace());
       set_zero(dev_ctx, d_x, static_cast<T>(0));
     }

     // Set Output(Y@Grad) be zero.
     if (d_y) {
       d_y->mutable_data<T>(ctx.GetPlace());
       set_zero(dev_ctx, d_y, static_cast<T>(0));
     }

+    if (d_weight) {
+      d_weight->mutable_data<T>(ctx.GetPlace());
+    }
+
     auto blas = math::GetBlas<DeviceContext, T>(ctx);

     // Caculate the Output(X@Grad) and Output(Y@Grad).
-    if (d_x || d_y) {
+    if (d_x || d_y || d_weight) {
       Eigen::DSizes<int, 2> bcast_for_x(1, y_dim);
       Eigen::DSizes<int, 2> bcast_for_y(1, x_dim);
+      Eigen::DSizes<int, 2> bcast_for_weight(1, x_dim);

       for (int i = 0; i < out_dim; ++i) {
         Tensor weight_i = weight->Slice(i, i + 1).Resize(
             framework::make_ddim({x_dim, y_dim}));
         auto output_vec = d_out_mat.chip(i, 1);

         if (d_x) {
           y_scale_mat.device(place) =
               output_vec.reshape(Eigen::DSizes<int, 2>(batch_size, 1))
-                  .broadcast(bcast_for_x) *
+                  .broadcast(bcast_for_x)
+                  .eval() *
               y_mat;
           blas.GEMM(CblasNoTrans, CblasTrans, batch_size, x_dim, y_dim, 1,
                     y_scale.data<T>(), weight_i.data<T>(), 1, d_x->data<T>());
         }

-        if (d_y) {
-          x_scale_mat.device(place) =
+        if (d_y || d_weight) {
+          auto output_vec_y =
               output_vec.reshape(Eigen::DSizes<int, 2>(batch_size, 1))
-                  .broadcast(bcast_for_y) *
-              x_mat;
-          blas.GEMM(CblasNoTrans, CblasNoTrans, batch_size, y_dim, x_dim, 1,
-                    x_scale.data<T>(), weight_i.data<T>(), 1, d_y->data<T>());
+                  .broadcast(bcast_for_y)
+                  .eval();
+          x_scale_mat.device(place) = output_vec_y * x_mat;
+          if (d_y) {
+            blas.GEMM(CblasNoTrans, CblasNoTrans, batch_size, y_dim, x_dim, 1,
+                      x_scale.data<T>(), weight_i.data<T>(), 1, d_y->data<T>());
+          }
+          if (d_weight) {
+            Tensor d_weight_i = d_weight->Slice(i, i + 1).Resize(
+                framework::make_ddim({x_dim, y_dim}));
+            blas.GEMM(CblasTrans, CblasNoTrans, x_dim, y_dim, batch_size, 1,
+                      x_scale.data<T>(), y->data<T>(), 0, d_weight_i.data<T>());
+          }
         }
       }
     }

-    // Caculate the gradient of Input(Weight).
-    if (d_weight) {
-      d_weight->mutable_data<T>(ctx.GetPlace());
-      Eigen::DSizes<int, 2> bcast_for_weight(1, x_dim);
-      for (int i = 0; i < out_dim; ++i) {
-        Tensor d_weight_i = d_weight->Slice(i, i + 1).Resize(
-            framework::make_ddim({x_dim, y_dim}));
-        auto output_vec = d_out_mat.chip(i, 1);
-        x_scale_mat.device(place) =
-            output_vec.reshape(Eigen::DSizes<int, 2>(batch_size, 1))
-                .broadcast(bcast_for_weight) *
-            x_mat;
-        blas.GEMM(CblasTrans, CblasNoTrans, x_dim, y_dim, batch_size, 1,
-                  x_scale.data<T>(), y->data<T>(), 0, d_weight_i.data<T>());
-      }
-    }
-
-    // Caculate the gradient of Input(Bias).
+    // calculate the gradient of Input(Bias).
     if (d_bias) {
       d_bias->mutable_data<T>(ctx.GetPlace());
       auto d_bias_mat = framework::EigenVector<T>::Flatten(*d_bias);
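The net effect of this hunk: the Input(Weight) gradient moves into the same i-loop as Input(Y@Grad), so x_scale is computed once per output index and reused for both, and the broadcast feeding it gains the same `.eval()` materialization as the forward kernel. Note the beta arguments on the GEMM calls: d_x and d_y use beta = 1 and accumulate across the out_dim loop, while each d_weight slice uses beta = 0 and is written exactly once. A plain-loop reference of what the GEMMs accumulate (an editor's sketch, every name hypothetical; callers are assumed to have zero-initialized the outputs, matching set_zero above):

    #include <cstddef>
    #include <vector>

    using Mat = std::vector<std::vector<float>>;
    using Ten3 = std::vector<Mat>;

    // Scalar reference for the gradients computed by the merged loop above.
    // d_x and d_y are accumulated across i (GEMM beta = 1); each d_W[i]
    // slice is produced by a single GEMM with beta = 0.
    void BilinearBackwardRef(const Mat& x, const Ten3& W, const Mat& y,
                             const Mat& d_out, Mat* d_x, Mat* d_y, Ten3* d_W) {
      const std::size_t batch_size = x.size();
      const std::size_t out_dim = W.size();
      for (std::size_t i = 0; i < out_dim; ++i) {
        for (std::size_t b = 0; b < batch_size; ++b) {
          const float g = d_out[b][i];  // one column of d_out, cf. chip(i, 1)
          for (std::size_t j = 0; j < x[b].size(); ++j) {
            for (std::size_t k = 0; k < y[b].size(); ++k) {
              (*d_x)[b][j] += g * W[i][j][k] * y[b][k];   // via y_scale * W_i^T
              (*d_y)[b][k] += g * x[b][j] * W[i][j][k];   // via x_scale * W_i
              (*d_W)[i][j][k] += g * x[b][j] * y[b][k];   // via x_scale^T * y
            }
          }
        }
      }
    }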