refine notation in bilinear_tensor_product_op.h

mobile_baidu
peterzhang2029 7 years ago
parent 5cf8204171
commit 5f99ae908b

@ -27,10 +27,6 @@ template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename Place, typename T>
class BilinearTensorProductKernel : public framework::OpKernel<T> {
public:
@ -49,7 +45,9 @@ class BilinearTensorProductKernel : public framework::OpKernel<T> {
auto weight_dims = weight->dims();
auto place = ctx.GetEigenDevice<Place>();
// Create the intermediate variables.
// Create the intermediate variable to caculate the result of
// Input(X) multiplied by Input(Weight_i), the formula is:
// left_mul = X Weight_i.
Tensor left_mul;
left_mul.mutable_data<T>(framework::make_ddim({batch_size, weight_dims[2]}),
ctx.GetPlace());
@ -95,11 +93,13 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
auto d_out_mat = EigenMatrix<T>::From(*d_out);
auto place = ctx.GetEigenDevice<Place>();
// Create the intermediate variables for gradient.
// Create the intermediate variable to caculate the Output(Y@Grad).
Tensor x_scale;
x_scale.mutable_data<T>(framework::make_ddim({batch_size, weight_dims[1]}),
ctx.GetPlace());
auto x_scale_mat = EigenMatrix<T>::From(x_scale);
// Create the intermediate variable to caculate the Output(X@Grad).
Tensor y_scale;
y_scale.mutable_data<T>(framework::make_ddim({batch_size, weight_dims[2]}),
ctx.GetPlace());
@ -107,19 +107,19 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
math::SetConstant<Place, T> set_zero;
// Set X@Grad be zero at first.
// Set Output(X@Grad) be zero.
if (d_x) {
d_x->mutable_data<T>(ctx.GetPlace());
set_zero(ctx.device_context(), d_x, static_cast<T>(0));
}
// Set Y@Grad be zero at first.
// Set Output(Y@Grad) be zero.
if (d_y) {
d_y->mutable_data<T>(ctx.GetPlace());
set_zero(ctx.device_context(), d_y, static_cast<T>(0));
}
// Caculate the X@Grad and Y@Grad.
// Caculate the Output(X@Grad) and Output(Y@Grad).
if (d_x || d_y) {
Eigen::DSizes<int, 2> bcast_for_x(1, weight_dims[2]);
Eigen::DSizes<int, 2> bcast_for_y(1, weight_dims[1]);
@ -150,7 +150,7 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
}
}
// Caculate the gradient of Weight.
// Caculate the gradient of Input(Weight).
if (d_weight) {
d_weight->mutable_data<T>(ctx.GetPlace());
Eigen::DSizes<int, 2> bcast_for_weight(1, weight_dims[1]);
@ -169,7 +169,7 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
}
}
// Caculate the gradient of Bias.
// Caculate the gradient of Input(Bias).
if (d_bias) {
d_bias->mutable_data<T>(ctx.GetPlace());
auto d_bias_mat = EigenMatrix<T>::From(*d_bias);

Loading…
Cancel
Save