|
|
|
@ -27,10 +27,6 @@ template <typename T, int MajorType = Eigen::RowMajor,
|
|
|
|
|
typename IndexType = Eigen::DenseIndex>
|
|
|
|
|
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
|
|
|
|
|
|
|
|
|
|
template <typename T, int MajorType = Eigen::RowMajor,
|
|
|
|
|
typename IndexType = Eigen::DenseIndex>
|
|
|
|
|
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
|
|
|
|
|
|
|
|
|
|
template <typename Place, typename T>
|
|
|
|
|
class BilinearTensorProductKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
@ -49,7 +45,9 @@ class BilinearTensorProductKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto weight_dims = weight->dims();
|
|
|
|
|
auto place = ctx.GetEigenDevice<Place>();
|
|
|
|
|
|
|
|
|
|
// Create the intermediate variables.
|
|
|
|
|
// Create the intermediate variable to caculate the result of
|
|
|
|
|
// Input(X) multiplied by Input(Weight_i), the formula is:
|
|
|
|
|
// left_mul = X Weight_i.
|
|
|
|
|
Tensor left_mul;
|
|
|
|
|
left_mul.mutable_data<T>(framework::make_ddim({batch_size, weight_dims[2]}),
|
|
|
|
|
ctx.GetPlace());
|
|
|
|
@ -95,11 +93,13 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto d_out_mat = EigenMatrix<T>::From(*d_out);
|
|
|
|
|
auto place = ctx.GetEigenDevice<Place>();
|
|
|
|
|
|
|
|
|
|
// Create the intermediate variables for gradient.
|
|
|
|
|
// Create the intermediate variable to caculate the Output(Y@Grad).
|
|
|
|
|
Tensor x_scale;
|
|
|
|
|
x_scale.mutable_data<T>(framework::make_ddim({batch_size, weight_dims[1]}),
|
|
|
|
|
ctx.GetPlace());
|
|
|
|
|
auto x_scale_mat = EigenMatrix<T>::From(x_scale);
|
|
|
|
|
|
|
|
|
|
// Create the intermediate variable to caculate the Output(X@Grad).
|
|
|
|
|
Tensor y_scale;
|
|
|
|
|
y_scale.mutable_data<T>(framework::make_ddim({batch_size, weight_dims[2]}),
|
|
|
|
|
ctx.GetPlace());
|
|
|
|
@ -107,19 +107,19 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
|
|
|
|
|
math::SetConstant<Place, T> set_zero;
|
|
|
|
|
|
|
|
|
|
// Set X@Grad be zero at first.
|
|
|
|
|
// Set Output(X@Grad) be zero.
|
|
|
|
|
if (d_x) {
|
|
|
|
|
d_x->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
set_zero(ctx.device_context(), d_x, static_cast<T>(0));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Set Y@Grad be zero at first.
|
|
|
|
|
// Set Output(Y@Grad) be zero.
|
|
|
|
|
if (d_y) {
|
|
|
|
|
d_y->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
set_zero(ctx.device_context(), d_y, static_cast<T>(0));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Caculate the X@Grad and Y@Grad.
|
|
|
|
|
// Caculate the Output(X@Grad) and Output(Y@Grad).
|
|
|
|
|
if (d_x || d_y) {
|
|
|
|
|
Eigen::DSizes<int, 2> bcast_for_x(1, weight_dims[2]);
|
|
|
|
|
Eigen::DSizes<int, 2> bcast_for_y(1, weight_dims[1]);
|
|
|
|
@ -150,7 +150,7 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Caculate the gradient of Weight.
|
|
|
|
|
// Caculate the gradient of Input(Weight).
|
|
|
|
|
if (d_weight) {
|
|
|
|
|
d_weight->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
Eigen::DSizes<int, 2> bcast_for_weight(1, weight_dims[1]);
|
|
|
|
@ -169,7 +169,7 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Caculate the gradient of Bias.
|
|
|
|
|
// Caculate the gradient of Input(Bias).
|
|
|
|
|
if (d_bias) {
|
|
|
|
|
d_bias->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
auto d_bias_mat = EigenMatrix<T>::From(*d_bias);
|
|
|
|
|