@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #include <cublas.h>
+#include <algorithm>
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/operators/rank_attention.cu.h"
@@ -32,7 +33,10 @@ class RankAttentionCUDAKernel : public framework::OpKernel<T> {
     auto *X = ctx.Input<Tensor>("X");
     auto *rank_offset = ctx.Input<Tensor>("RankOffset");
     auto *param = ctx.Input<Tensor>("RankParam");
+    auto *input_help = ctx.Output<Tensor>("InputHelp");
+    auto *ins_rank = ctx.Output<Tensor>("InsRank");
     int max_rank = ctx.Attr<int>("MaxRank");
+    int64_t max_size = ctx.Attr<int>("MaxSize");
     auto *Out = ctx.Output<Tensor>("Out");
     // check dims
@@ -56,37 +60,42 @@ class RankAttentionCUDAKernel : public framework::OpKernel<T> {
     int block_matrix_row = max_rank * x_fea_dim;
     auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
-    auto stream = ctx.cuda_device_context().stream();
-    int device_id = platform::GetCurrentDeviceId();
-    T *param_help_data;
-    auto param_help_size = ins_num * block_matrix_row * para_col * sizeof(T);
-    platform::RecordedCudaMalloc(reinterpret_cast<void **>(&param_help_data),
-                                 param_help_size, device_id);
-    platform::GpuMemsetAsync(param_help_data, 0, param_help_size, stream);
-    T *input_help_data;
-    auto input_help_size = ins_num * block_matrix_row * sizeof(T);
-    platform::RecordedCudaMalloc(reinterpret_cast<void **>(&input_help_data),
-                                 input_help_size, device_id);
-    platform::GpuMemsetAsync(input_help_data, 0, input_help_size, stream);
-    T *ins_rank_data;
-    auto ins_rank_size = ins_num * sizeof(T);
-    platform::RecordedCudaMalloc(reinterpret_cast<void **>(&ins_rank_data),
-                                 ins_rank_size, device_id);
-    platform::GpuMemsetAsync(ins_rank_data, -1, ins_rank_size, stream);
+    int max_ins = std::max(ins_num, max_size);
+    Tensor param_help;
+    param_help = ctx.AllocateTmpTensor<T, DeviceContext>(
+        {max_ins * block_matrix_row, para_col}, dev_ctx);
+    param_help.mutable_data<T>(ctx.GetPlace());
+    input_help->Resize({max_ins, block_matrix_row});
+    ins_rank->Resize({max_ins, 1});
+    input_help->mutable_data<T>(ctx.GetPlace());
+    ins_rank->mutable_data<T>(ctx.GetPlace());
     Out->mutable_data<T>(ctx.GetPlace());
+    // initialize
+    auto param_help_eigen = framework::EigenVector<T>::Flatten(param_help);
+    auto input_help_eigen = framework::EigenVector<T>::Flatten(*input_help);
+    auto ins_rank_eigen = framework::EigenVector<T>::Flatten(*ins_rank);
     auto out_eigen = framework::EigenVector<T>::Flatten(*Out);
     auto &place = *ctx.template device_context<platform::CUDADeviceContext>()
                        .eigen_device();
+    param_help_eigen.device(place) =
+        param_help_eigen.constant(static_cast<T>(0));
+    input_help_eigen.device(place) =
+        input_help_eigen.constant(static_cast<T>(0));
+    ins_rank_eigen.device(place) = ins_rank_eigen.constant(static_cast<T>(-1));
     out_eigen.device(place) = out_eigen.constant(static_cast<T>(0));
+    // get data ptr
+    T *input_help_data = input_help->data<T>();
+    T *param_help_data = param_help.data<T>();
+    T *ins_rank_data = ins_rank->data<T>();
     T *out_data = Out->data<T>();
     expand_rank_attention_input(
         ctx.cuda_device_context().stream(), X->data<T>(), ins_num, x_fea_dim,
         input_help_data, ins_num, block_matrix_row, rank_offset->data<int>(),
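
A rough CPU sketch of what expand_rank_attention_input appears to compute, judging from the arguments it is given here: for each instance, gather up to max_rank blocks of x_fea_dim features from X into one row of input_help, and record the instance's rank in ins_rank (entries stay at their initial -1 when the instance has no rank). The rank_offset column layout assumed below (column 0 = the instance's own rank, then (rank, instance index) pairs) is a guess for illustration only, not taken from this patch:

    #include <cstddef>
    #include <vector>
    // Hypothetical layout: rank_offset is ins_num x (1 + 2 * max_rank);
    // column 0 is the instance's own rank, then (rank, ins_index) pairs.
    // input_help row i holds max_rank consecutive blocks of x_fea_dim values.
    template <typename T>
    void ExpandRankAttentionInputSketch(
        const std::vector<T> &X,              // ins_num x x_fea_dim
        const std::vector<int> &rank_offset,  // ins_num x (1 + 2 * max_rank)
        std::vector<T> *input_help,  // ins_num x (max_rank * x_fea_dim), zeroed
        std::vector<T> *ins_rank,    // ins_num entries, pre-filled with -1
        int ins_num, int x_fea_dim, int max_rank) {
      const int cols = 1 + 2 * max_rank;
      for (int i = 0; i < ins_num; ++i) {
        int own_rank = rank_offset[i * cols];
        if (own_rank <= 0) continue;  // no rank: row stays zero, ins_rank stays -1
        (*ins_rank)[i] = static_cast<T>(own_rank);
        for (int k = 0; k < max_rank; ++k) {
          int idx = rank_offset[i * cols + 2 * k + 2];  // source instance, < 0 = absent
          if (idx < 0) continue;                        // absent block stays zero
          for (int f = 0; f < x_fea_dim; ++f) {
            (*input_help)[(static_cast<std::size_t>(i) * max_rank + k) * x_fea_dim + f] =
                X[static_cast<std::size_t>(idx) * x_fea_dim + f];
          }
        }
      }
    }
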
@@ -110,10 +119,6 @@ class RankAttentionCUDAKernel : public framework::OpKernel<T> {
     blas.BatchedGEMM(transA, transB, 1, para_col, block_matrix_row, alpha,
                      input_help_data, param_help_data, beta, out_data, ins_num,
                      strideA, strideB);
-    platform::RecordedCudaFree(param_help_data, param_help_size, device_id);
-    platform::RecordedCudaFree(input_help_data, input_help_size, device_id);
-    platform::RecordedCudaFree(ins_rank_data, ins_rank_size, device_id);
   }
 };
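
The forward BatchedGEMM above runs with M = 1, N = para_col, K = block_matrix_row and batchCount = ins_num, i.e. one row-vector-times-matrix product per instance. A minimal CPU reference of those semantics, assuming the usual strided-batch layout (strideA = K for input_help; strideB = K * N for param_help is inferred from the shapes, since its value is not visible in this hunk):

    #include <cstddef>
    #include <vector>
    // out[i, :] (1 x N) = input_help[i, :] (1 x K) * param_help[i] (K x N)
    template <typename T>
    void ForwardBatchedGemmReference(
        const std::vector<T> &input_help,  // ins_num x K
        const std::vector<T> &param_help,  // ins_num x K x N
        std::vector<T> *out,               // ins_num x N
        int ins_num, int K, int N) {
      for (int i = 0; i < ins_num; ++i) {
        const T *a = input_help.data() + static_cast<std::size_t>(i) * K;
        const T *b = param_help.data() + static_cast<std::size_t>(i) * K * N;
        T *c = out->data() + static_cast<std::size_t>(i) * N;
        for (int n = 0; n < N; ++n) {
          T acc = 0;
          for (int k = 0; k < K; ++k) acc += a[k] * b[k * N + n];
          c[n] = acc;  // alpha = 1, beta = 0
        }
      }
    }
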
@@ -121,10 +126,13 @@ template <typename DeviceContext, typename T>
 class RankAttentionGradOpCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
-    auto *X = ctx.Input<Tensor>("X");
-    auto *rank_offset = ctx.Input<Tensor>("RankOffset");
-    auto *param = ctx.Input<Tensor>("RankParam");
+    auto *X = ctx.Input<Tensor>("X");  // not use data
+    auto *rank_offset = ctx.Input<Tensor>("RankOffset");  // not use data
+    auto *param = ctx.Input<Tensor>("RankParam");  // not use data
+    auto *input_help = ctx.Input<Tensor>("InputHelp");
+    auto *ins_rank = ctx.Input<Tensor>("InsRank");
     auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    int64_t max_size = ctx.Attr<int>("MaxSize");
     auto *drank_para = ctx.Output<Tensor>(framework::GradVarName("RankParam"));
@@ -142,38 +150,26 @@ class RankAttentionGradOpCUDAKernel : public framework::OpKernel<T> {
     auto &place = *ctx.template device_context<platform::CUDADeviceContext>()
                        .eigen_device();
+    int max_ins = std::max(ins_num, max_size);
     // initialize out grad
     drank_para->mutable_data<T>(ctx.GetPlace());
     auto drank_para_eigen = framework::EigenVector<T>::Flatten(*drank_para);
     drank_para_eigen.device(place) =
         drank_para_eigen.constant(static_cast<T>(0));
-    auto stream = ctx.cuda_device_context().stream();
-    int device_id = platform::GetCurrentDeviceId();
-    T *param_grad_data;
-    auto param_grad_size = ins_num * block_matrix_row * para_col * sizeof(T);
-    platform::RecordedCudaMalloc(reinterpret_cast<void **>(&param_grad_data),
-                                 param_grad_size, device_id);
-    platform::GpuMemsetAsync(param_grad_data, 0, param_grad_size, stream);
-    T *input_help_data;
-    auto input_help_size = ins_num * block_matrix_row * sizeof(T);
-    platform::RecordedCudaMalloc(reinterpret_cast<void **>(&input_help_data),
-                                 input_help_size, device_id);
-    platform::GpuMemsetAsync(input_help_data, 0, input_help_size, stream);
-    T *ins_rank_data;
-    auto ins_rank_size = ins_num * sizeof(T);
-    platform::RecordedCudaMalloc(reinterpret_cast<void **>(&ins_rank_data),
-                                 ins_rank_size, device_id);
-    platform::GpuMemsetAsync(ins_rank_data, -1, ins_rank_size, stream);
-    // expand input
-    expand_rank_attention_input(
-        ctx.cuda_device_context().stream(), X->data<T>(), ins_num, x_fea_dim,
-        input_help_data, ins_num, block_matrix_row, rank_offset->data<int>(),
-        rank_offset_dims[0], rank_offset_dims[1], ins_rank_data, max_rank);
+    // copy data
+    Tensor param_grad;
+    param_grad = ctx.AllocateTmpTensor<T, DeviceContext>(
+        {max_ins * block_matrix_row, para_col}, dev_ctx);
+    param_grad.mutable_data<T>(ctx.GetPlace());
+    // initialize
+    auto param_grad_eigen = framework::EigenVector<T>::Flatten(param_grad);
+    param_grad_eigen.device(place) =
+        param_grad_eigen.constant(static_cast<T>(0));
+    // get data ptr
+    const T *input_help_data = input_help->data<T>();
+    const T *ins_rank_data = ins_rank->data<T>();
+    T *param_grad_data = param_grad.data<T>();
     auto blas = math::GetBlas<platform::CUDADeviceContext, T>(dev_ctx);
     T alpha = 1;
@@ -184,20 +180,14 @@ class RankAttentionGradOpCUDAKernel : public framework::OpKernel<T> {
     CBLAS_TRANSPOSE transB = CblasNoTrans;
     int64_t strideA = block_matrix_row;
     int64_t strideB = para_col;
     blas.BatchedGEMM(transA, transB, block_matrix_row, para_col, 1, alpha,
                      input_help_data, dout->data<T>(), beta, param_grad_data,
                      ins_num, strideA, strideB);
     // merge param_grad to get drank_para
     merge_rank_attention_param_grad(
         ctx.cuda_device_context().stream(), param_grad_data,
         ins_num * block_matrix_row, para_col, drank_para->data<T>(), para_row,
         para_col, ins_rank_data, ins_num, max_rank, x_fea_dim);
-    platform::RecordedCudaFree(param_grad_data, param_grad_size, device_id);
-    platform::RecordedCudaFree(input_help_data, input_help_size, device_id);
-    platform::RecordedCudaFree(ins_rank_data, ins_rank_size, device_id);
   }
 };
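
The grad BatchedGEMM runs with M = block_matrix_row, N = para_col, K = 1 and batchCount = ins_num, so each instance contributes an outer product param_grad[i] = input_help[i, :]^T * dout[i, :]; merge_rank_attention_param_grad then reduces those per-instance blocks into drank_para according to ins_rank. A CPU sketch of the outer-product step only (the merge kernel is not shown in this patch):

    #include <cstddef>
    #include <vector>
    // param_grad[i] (M x N) = input_help[i, :]^T (M x 1) * dout[i, :] (1 x N),
    // matching strideA = block_matrix_row (= M) and strideB = para_col (= N).
    template <typename T>
    void GradBatchedGemmReference(
        const std::vector<T> &input_help,  // ins_num x M
        const std::vector<T> &dout,        // ins_num x N
        std::vector<T> *param_grad,        // ins_num x M x N
        int ins_num, int M, int N) {
      for (int i = 0; i < ins_num; ++i) {
        const T *a = input_help.data() + static_cast<std::size_t>(i) * M;
        const T *b = dout.data() + static_cast<std::size_t>(i) * N;
        T *c = param_grad->data() + static_cast<std::size_t>(i) * M * N;
        for (int m = 0; m < M; ++m)
          for (int n = 0; n < N; ++n) c[m * N + n] = a[m] * b[n];  // alpha=1, beta=0
      }
    }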