@@ -14,15 +14,22 @@
 #pragma once
 
 #include "paddle/framework/eigen.h"
 #include "paddle/framework/op_registry.h"
 #include "paddle/operators/math/math_function.h"
-#include "paddle/platform/transform.h"
 
 namespace paddle {
 namespace operators {
 
 using Tensor = framework::Tensor;
-using platform::Transform;
 
 template <typename T, int MajorType = Eigen::RowMajor,
           typename IndexType = Eigen::DenseIndex>
 using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
 
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
+
 template <typename Place, typename T>
 class BilinearTensorProductKernel : public framework::OpKernel<T> {
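
For orientation, the computation both kernels implement: Out[b][i] = x_b * W_i * y_b^T + bias[i], with X of shape [batch, d1], Y of shape [batch, d2], and Weight of shape [k, d1, d2]. A naive scalar sketch of that formula (assumed shapes and names, not part of the patch):

    #include <vector>

    // Naive bilinear tensor product: out[b][i] = x_b * W_i * y_b^T + bias[i].
    // Shapes assumed: x [batch, d1], y [batch, d2], w [k, d1, d2] row-major,
    // bias [k] (may be empty), out [batch, k].
    void BilinearRef(const std::vector<float>& x, const std::vector<float>& y,
                     const std::vector<float>& w, const std::vector<float>& bias,
                     std::vector<float>& out, int batch, int d1, int d2, int k) {
      for (int b = 0; b < batch; ++b) {
        for (int i = 0; i < k; ++i) {
          float acc = bias.empty() ? 0.f : bias[i];
          for (int p = 0; p < d1; ++p)
            for (int q = 0; q < d2; ++q)
              acc += x[b * d1 + p] * w[(i * d1 + p) * d2 + q] * y[b * d2 + q];
          out[b * k + i] = acc;
        }
      }
    }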
@@ -35,43 +42,45 @@ class BilinearTensorProductKernel : public framework::OpKernel<T> {
     auto* out = ctx.Output<Tensor>("Out");
     out->mutable_data<T>(ctx.GetPlace());
 
+    auto y_mat = EigenMatrix<T>::From(*y);
+    auto output_mat = EigenMatrix<T>::From(*out);
+
     auto batch_size = x->dims()[0];
     auto weight_dims = weight->dims();
-    Tensor left_mul_vec;
-    left_mul_vec.mutable_data<T>(framework::make_ddim({weight_dims[2]}),
-                                 ctx.GetPlace());
-    if (bias) {
-      out->CopyFrom(*bias, ctx.GetPlace(), ctx.device_context());
-    }
-    for (int i = 0; i < weight_dims[0]; ++i) {
+    auto place = ctx.GetEigenDevice<Place>();
+
+    // Create the temporary variables.
+    Tensor left_mul;
+    left_mul.mutable_data<T>(framework::make_ddim({batch_size, weight_dims[2]}),
+                             ctx.GetPlace());
+    auto left_mul_mat = EigenMatrix<T>::From(left_mul);
+    Tensor output_col;
+    output_col.mutable_data<T>(framework::make_ddim({weight_dims[0]}),
+                               ctx.GetPlace());
+    auto output_col_vec = EigenVector<T>::From(output_col);
+
+    for (size_t i = 0; i < weight_dims[0]; ++i) {
       Tensor weight_mat = weight->Slice(i, i + 1).Resize(
           framework::make_ddim({weight_dims[1], weight_dims[2]}));
-      math::gemm<Place, T>(ctx.device_context(), CblasNoTrans, CblasNoTrans, 1,
-                           weight_dims[2], weight_dims[1], 1, x->data<T>(),
-                           weight_mat.data<T>(), 0, left_mul_vec.data<T>());
-      if (bias) {
-        math::gemm<Place, T>(ctx.device_context(), CblasNoTrans, CblasNoTrans,
-                             1, 1, weight_dims[2], 1, left_mul_vec.data<T>(),
-                             y->data<T>(), 1, &(out->data<T>()[i]));
-      } else {
-        math::gemm<Place, T>(ctx.device_context(), CblasNoTrans, CblasNoTrans,
-                             1, 1, weight_dims[2], 1, left_mul_vec.data<T>(),
-                             y->data<T>(), 0, &(out->data<T>()[i]));
+      math::gemm<Place, T>(ctx.device_context(), CblasNoTrans, CblasNoTrans,
+                           batch_size, weight_dims[2], weight_dims[1], 1,
+                           x->data<T>(), weight_mat.data<T>(), 0,
+                           left_mul.data<T>());
+      output_col_vec = (left_mul_mat * y_mat).sum(Eigen::DSizes<int, 1>(1));
+      for (size_t j = 0; j < batch_size; ++j) {
+        output_mat(j, i) = output_col_vec(j);
       }
     }
+    if (bias) {
+      auto bias_vec = EigenMatrix<T>::From(*bias);
+      Eigen::DSizes<int, 2> bcast(batch_size, 1);
+      output_mat.device(place) = bias_vec.broadcast(bcast) + output_mat;
+    } else {
+      output_mat.device(place) = output_mat;
+    }
   }
 };
 
-template <typename T>
-class ScaleFunctor {
- public:
-  explicit ScaleFunctor(const T* scale) : scale_(scale) {}
-
-  HOSTDEVICE T operator()(const T& x) const { return x * (*scale_); }
-
- private:
-  const T* scale_;
-};
-
 template <typename Place, typename T>
 class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
  public:
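
The rewritten forward kernel evaluates one output column per weight slice: a single [batch, d1] x [d1, d2] gemm forms left_mul = X * W_i, then a row-wise sum of left_mul .* Y fills Out[:, i]. The old code fixed the gemm batch dimension to 1 and wrote through out->data<T>()[i], so it computed only the first sample of the batch; the bias is now added once by broadcast after the loop instead of being copied into Out up front. A loop-level sketch of the new evaluation order (assumed names, plain loops standing in for gemm and Eigen):

    #include <vector>

    // One output column: left_mul = x * w_i, then out_col[b] = dot(left_mul[b], y[b]).
    void BilinearColumn(const std::vector<float>& x, const std::vector<float>& y,
                        const std::vector<float>& w_i,  // slice i, [d1, d2]
                        std::vector<float>& out_col,    // column i, [batch]
                        int batch, int d1, int d2) {
      std::vector<float> left_mul(batch * d2, 0.f);
      for (int b = 0; b < batch; ++b)    // left_mul = x * w_i (one gemm)
        for (int p = 0; p < d1; ++p)
          for (int q = 0; q < d2; ++q)
            left_mul[b * d2 + q] += x[b * d1 + p] * w_i[p * d2 + q];
      for (int b = 0; b < batch; ++b) {  // rowsum(left_mul .* y)
        float acc = 0.f;
        for (int q = 0; q < d2; ++q) acc += left_mul[b * d2 + q] * y[b * d2 + q];
        out_col[b] = acc;
      }
    }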
@@ -84,66 +93,65 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
     Tensor* d_weight = ctx.Output<Tensor>(framework::GradVarName("Weight"));
     Tensor* d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
     const Tensor* d_out = ctx.Input<Tensor>(framework::GradVarName("Out"));
-    auto* d_out_ptr = d_out->data<T>();
 
     auto batch_size = x->dims()[0];
     auto weight_dims = weight->dims();
 
-    // Get the first matrix of Weight.
-    Tensor weight_mat_0 = weight->Slice(0, 1).Resize(
-        framework::make_ddim({weight_dims[1], weight_dims[2]}));
+    auto x_mat = EigenMatrix<T>::From(*x);
+    auto y_mat = EigenMatrix<T>::From(*y);
+    auto d_out_mat = EigenMatrix<T>::From(*d_out);
+    auto place = ctx.GetEigenDevice<Place>();
 
-    // Create the intermediate variable for gradient.
-    int numel_x = x->numel();
-    int numel_y = y->numel();
-    const T* x_ptr = x->data<T>();
-    const T* y_ptr = y->data<T>();
+    // Create the temporary variables for gradient.
     Tensor x_scale;
-    T* x_scale_ptr = x_scale.mutable_data<T>(
-        framework::make_ddim({weight_dims[1]}), ctx.GetPlace());
+    x_scale.mutable_data<T>(framework::make_ddim({batch_size, weight_dims[1]}),
+                            ctx.GetPlace());
+    auto x_scale_mat = EigenMatrix<T>::From(x_scale);
     Tensor y_scale;
-    T* y_scale_ptr = y_scale.mutable_data<T>(
-        framework::make_ddim({weight_dims[2]}), ctx.GetPlace());
-    Transform<Place> trans;
+    y_scale.mutable_data<T>(framework::make_ddim({batch_size, weight_dims[2]}),
+                            ctx.GetPlace());
+    auto y_scale_mat = EigenMatrix<T>::From(y_scale);
+
+    math::SetConstant<Place, T> set_zero;
 
-    // Caculate the gradient of X according to the first matrix of Weight.
+    // Set X@Grad be zero at first.
     if (d_x) {
       d_x->mutable_data<T>(ctx.GetPlace());
-      trans(ctx.device_context(), y_ptr, y_ptr + numel_y, y_scale_ptr,
-            ScaleFunctor<T>(&d_out_ptr[0]));
-      math::gemm<Place, T>(ctx.device_context(), CblasNoTrans, CblasTrans, 1,
-                           weight_dims[1], weight_dims[2], 1, y_scale.data<T>(),
-                           weight_mat_0.data<T>(), 0, d_x->data<T>());
+      set_zero(ctx.device_context(), d_x, static_cast<T>(0));
     }
 
-    // Caculate the gradient of Y according to the first matrix of Weight.
+    // Set Y@Grad be zero at first.
     if (d_y) {
       d_y->mutable_data<T>(ctx.GetPlace());
-      trans(ctx.device_context(), x_ptr, x_ptr + numel_x, x_scale_ptr,
-            ScaleFunctor<T>(&d_out_ptr[0]));
-      math::gemm<Place, T>(ctx.device_context(), CblasTrans, CblasNoTrans,
-                           weight_dims[2], 1, weight_dims[1], 1,
-                           weight_mat_0.data<T>(), x_scale.data<T>(), 0,
-                           d_y->data<T>());
+      set_zero(ctx.device_context(), d_y, static_cast<T>(0));
     }
 
-    // Caculate the gradient of X and Y completly.
+    // Caculate the X@Grad and Y@Grad.
     if (d_x || d_y) {
-      for (int i = 1; i < weight_dims[0]; ++i) {
-        Tensor weight_mat = weight->Slice(i, i + 1).Resize(
+      Eigen::DSizes<int, 2> bcast_for_x(1, weight_dims[2]);
+      Eigen::DSizes<int, 2> bcast_for_y(1, weight_dims[1]);
+      for (int i = 0; i < weight_dims[0]; ++i) {
+        Tensor weight_i = weight->Slice(i, i + 1).Resize(
            framework::make_ddim({weight_dims[1], weight_dims[2]}));
+        auto output_vec = d_out_mat.chip(i, 1);
         if (d_x) {
-          trans(ctx.device_context(), y_ptr, y_ptr + numel_y, y_scale_ptr,
-                ScaleFunctor<T>(&d_out_ptr[i]));
+          y_scale_mat.device(place) =
+              output_vec.reshape(Eigen::DSizes<int, 2>(batch_size, 1))
+                  .broadcast(bcast_for_x) *
+              y_mat;
           math::gemm<Place, T>(ctx.device_context(), CblasNoTrans, CblasTrans,
-                               1, weight_dims[1], weight_dims[2], 1,
-                               y_scale.data<T>(), weight_mat.data<T>(), 1,
+                               batch_size, weight_dims[1], weight_dims[2], 1,
+                               y_scale.data<T>(), weight_i.data<T>(), 1,
                                d_x->data<T>());
         }
         if (d_y) {
-          trans(ctx.device_context(), x_ptr, x_ptr + numel_x, x_scale_ptr,
-                ScaleFunctor<T>(&d_out_ptr[i]));
-          math::gemm<Place, T>(ctx.device_context(), CblasTrans, CblasNoTrans,
-                               weight_dims[2], 1, weight_dims[1], 1,
-                               weight_mat.data<T>(), x_scale.data<T>(), 1,
+          x_scale_mat.device(place) =
+              output_vec.reshape(Eigen::DSizes<int, 2>(batch_size, 1))
+                  .broadcast(bcast_for_y) *
+              x_mat;
+          math::gemm<Place, T>(ctx.device_context(), CblasNoTrans, CblasNoTrans,
+                               batch_size, weight_dims[2], weight_dims[1], 1,
+                               x_scale.data<T>(), weight_i.data<T>(), 1,
                                d_y->data<T>());
         }
       }
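
Because Out[b][i] = sum over p,q of x[b][p] * w[i][p][q] * y[b][q], the input gradients accumulate one update per output column: d_x += (d_out[:, i] .* Y) * W_i^T and d_y += (d_out[:, i] .* X) * W_i. That is what the broadcast-then-gemm pairs above compute, with beta = 1 so each column accumulates into the result, which is also why d_x and d_y must be zeroed first. A loop-level sketch of the d_x update for one column (assumed shapes as before):

    #include <vector>

    // d_x[b][p] += d_out[b][i] * sum_q w_i[p][q] * y[b][q], for one column i.
    void DXSlice(const std::vector<float>& y, const std::vector<float>& w_i,
                 const std::vector<float>& d_out,  // [batch, k]
                 std::vector<float>& d_x,          // [batch, d1], zero-initialized
                 int batch, int d1, int d2, int k, int i) {
      for (int b = 0; b < batch; ++b)
        for (int p = 0; p < d1; ++p) {
          float acc = 0.f;
          for (int q = 0; q < d2; ++q)
            acc += w_i[p * d2 + q] * y[b * d2 + q];  // (w_i * y_b^T)[p]
          d_x[b * d1 + p] += d_out[b * k + i] * acc;
        }
    }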
@@ -152,22 +160,27 @@ class BilinearTensorProductGradKernel : public framework::OpKernel<T> {
     // Caculate the gradient of Weight.
     if (d_weight) {
       d_weight->mutable_data<T>(ctx.GetPlace());
+      Eigen::DSizes<int, 2> bcast_for_weight(1, weight_dims[1]);
       for (int i = 0; i < weight_dims[0]; ++i) {
-        Tensor d_weight_mat = d_weight->Slice(i, i + 1).Resize(
+        Tensor d_weight_i = d_weight->Slice(i, i + 1).Resize(
             framework::make_ddim({weight_dims[1], weight_dims[2]}));
-        trans(ctx.device_context(), x_ptr, x_ptr + numel_x, x_scale_ptr,
-              ScaleFunctor<T>(&d_out_ptr[i]));
+        auto output_vec = d_out_mat.chip(i, 1);
+        x_scale_mat.device(place) =
+            output_vec.reshape(Eigen::DSizes<int, 2>(batch_size, 1))
+                .broadcast(bcast_for_weight) *
+            x_mat;
         math::gemm<Place, T>(ctx.device_context(), CblasTrans, CblasNoTrans,
-                             weight_dims[1], weight_dims[2], 1, 1,
+                             weight_dims[1], weight_dims[2], batch_size, 1,
                              x_scale.data<T>(), y->data<T>(), 0,
-                             d_weight_mat.data<T>());
+                             d_weight_i.data<T>());
       }
     }
 
     // Caculate the gradient of Bias.
     if (d_bias) {
       d_bias->mutable_data<T>(ctx.GetPlace());
-      d_bias->CopyFrom(*d_out, ctx.GetPlace(), ctx.device_context());
+      auto d_bias_mat = EigenMatrix<T>::From(*d_bias);
+      d_bias_mat.device(place) = d_out_mat.sum(Eigen::DSizes<int, 1>(0));
     }
   }
 };
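
The remaining gradients follow from the same expansion: d_w[i][p][q] = sum over b of d_out[b][i] * x[b][p] * y[b][q], computed above as (d_out[:, i] .* X)^T * Y with the gemm reduction dimension set to batch_size, and d_bias[i] = sum over b of d_out[b][i], a column sum that replaces the old CopyFrom (which was equivalent only for a batch of one). A loop-level sketch for one weight slice plus the bias column (assumed shapes as before):

    #include <vector>

    // d_w_i[p][q] = sum_b d_out[b][i] * x[b][p] * y[b][q];
    // d_bias[i]   = sum_b d_out[b][i].
    void DWeightSliceAndBias(const std::vector<float>& x,
                             const std::vector<float>& y,
                             const std::vector<float>& d_out,  // [batch, k]
                             std::vector<float>& d_w_i,        // [d1, d2], zeroed
                             std::vector<float>& d_bias,       // [k]
                             int batch, int d1, int d2, int k, int i) {
      d_bias[i] = 0.f;
      for (int b = 0; b < batch; ++b) {
        d_bias[i] += d_out[b * k + i];
        for (int p = 0; p < d1; ++p)
          for (int q = 0; q < d2; ++q)
            d_w_i[p * d2 + q] += d_out[b * k + i] * x[b * d1 + p] * y[b * d2 + q];
      }
    }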