fix conflict, test=develop (#23298)
parent 5223e2bbc4
commit c706ff20a3
@@ -0,0 +1,153 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/fluid/framework/dim.h"
#include "paddle/fluid/operators/math/math_function.h"

namespace paddle {
namespace operators {
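// CUDA_KERNEL_LOOP is a grid-stride loop: each thread starts at its global
// index and advances by the total number of launched threads, so any n is
// covered even when the grid is smaller than n. GET_BLOCKS below rounds N up
// to whole blocks of CUDA_NUM_THREADS threads.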
#define CUDA_KERNEL_LOOP(i, n)                                 \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
       i += blockDim.x * gridDim.x)

const int CUDA_NUM_THREADS = 1024;
static inline int GET_BLOCKS(const int N) {
  return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
}
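// Expands input ([input_row, input_col]) into output
// ([output_row, max_rank * input_col]): block k of row i is copied from the
// input row addressed by rank_offset column 2k+2; blocks without a valid rank
// keep their caller-initialized value (zero). Each instance's own rank is
// recorded in ins_rank.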
template <typename T>
__global__ void expand_input_by_rank_kernel(
    const T* input, int input_row, int input_col, T* output, int output_row,
    int output_col, const int* rank_offset, int rank_offset_row,
    int rank_offset_col, T* ins_rank, int max_rank) {
  CUDA_KERNEL_LOOP(idx, output_row * output_col) {
    int output_col_idx = idx % output_col;
    int output_row_idx = idx / output_col;
    int k = output_col_idx / input_col;

    int faster = rank_offset[output_row_idx * rank_offset_col + 2 * k + 1] - 1;
    if (output_col_idx == 0) {
      ins_rank[output_row_idx] = rank_offset[output_row_idx * rank_offset_col];
    }

    if (rank_offset[output_row_idx * rank_offset_col] - 1 < 0 || faster < 0) {
      continue;
    }

    int rank_input_col_idx = output_col_idx % input_col;
    int index = rank_offset[output_row_idx * rank_offset_col + 2 * k + 2];
    output[idx] = input[rank_input_col_idx + index * input_col];
  }
}

template <typename T>
void expand_rank_attention_input(cudaStream_t stream, const T* input,
                                 int input_row, int input_col, T* output,
                                 int output_row, int output_col,
                                 const int* rank_offset, int rank_offset_row,
                                 int rank_offset_col, T* ins_rank,
                                 int max_rank) {
  expand_input_by_rank_kernel<<<GET_BLOCKS(output_row * output_col),
                                CUDA_NUM_THREADS, 0, stream>>>(
      input, input_row, input_col, output, output_row, output_col, rank_offset,
      rank_offset_row, rank_offset_col, ins_rank, max_rank);
}
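// Gathers the parameter blocks each instance attends to: the block for
// (instance, k) is selected by start = (own_rank - 1) * max_rank +
// (rank_offset[.., 2k+1] - 1), matching the layout of RankParam
// ([max_rank * max_rank * input_col, param_col]).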
template <typename T>
__global__ void expand_rank_attention_param_kernel(
    const T* input, int input_row, int input_col, const int* rank_offset,
    int rank_offset_row, int rank_offset_col, const T* param, int param_row,
    int param_col, T* output_param, int output_param_row, int output_param_col,
    int max_rank) {
  CUDA_KERNEL_LOOP(idx, output_param_row * output_param_col) {
    int output_col_idx = idx % output_param_col;
    int output_row_idx = idx / output_param_col;

    int block_matrix_row = max_rank * input_col;
    int ins_idx = output_row_idx / block_matrix_row;
    int start_offset = output_row_idx % block_matrix_row;

    int k = start_offset / input_col;
    int k_offset = start_offset % input_col;

    int lower = rank_offset[ins_idx * rank_offset_col] - 1;
    int faster = rank_offset[2 * k + 1 + rank_offset_col * ins_idx] - 1;

    if (lower < 0 || faster < 0) {
      continue;
    }
    int start = lower * max_rank + faster;
    int ori_idx =
        start * param_col * input_col + k_offset * param_col + output_col_idx;
    output_param[idx] = param[ori_idx];
  }
}

template <typename T>
void expand_rank_attention_param(cudaStream_t stream, const T* input,
                                 int input_row, int input_col,
                                 const int* rank_offset, int rank_offset_row,
                                 int rank_offset_col, const T* param,
                                 int param_row, int param_col, T* output_param,
                                 int output_param_row, int output_param_col,
                                 int max_rank) {
  expand_rank_attention_param_kernel<<<GET_BLOCKS(output_param_row *
                                                  output_param_col),
                                       CUDA_NUM_THREADS, 0, stream>>>(
      input, input_row, input_col, rank_offset, rank_offset_row,
      rank_offset_col, param, param_row, param_col, output_param,
      output_param_row, output_param_col, max_rank);
}
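// Scatters the expanded per-instance gradient back onto param_grad: each
// param row belongs to a rank block of block_matrix_row rows, and that
// block's gradient is summed over all instances whose own rank (ins_rank)
// selects it.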
template <typename T>
__global__ void merge_param_gradient_kernel(
    T* expanded_grad, int expanded_grad_row, int expanded_grad_col,
    T* param_grad, int param_grad_row, int param_grad_col, const T* ins_rank,
    int ins_num, int max_rank, int input_col) {
  CUDA_KERNEL_LOOP(tid, param_grad_row * param_grad_col) {
    int param_col_idx = tid % param_grad_col;
    int param_row_idx = tid / param_grad_col;

    int block_matrix_row = max_rank * input_col;
    int rank_idx = param_row_idx / block_matrix_row;
    int rank_offset = param_row_idx % block_matrix_row;

    T tmp = 0;
    for (int i = 0; i < ins_num; ++i) {
      if (ins_rank[i] == rank_idx + 1) {
        int row = i * block_matrix_row + rank_offset;
        tmp += expanded_grad[row * expanded_grad_col + param_col_idx];
      }
    }
    param_grad[tid] = tmp;
  }
}

template <typename T>
void merge_rank_attention_param_grad(cudaStream_t stream, T* expanded_grad,
                                     int expanded_grad_row,
                                     int expanded_grad_col, T* param_grad,
                                     int param_grad_row, int param_grad_col,
                                     const T* ins_rank, int ins_num,
                                     int max_rank, int input_col) {
  merge_param_gradient_kernel<<<GET_BLOCKS(param_grad_row * param_grad_col),
                                CUDA_NUM_THREADS, 0, stream>>>(
      expanded_grad, expanded_grad_row, expanded_grad_col, param_grad,
      param_grad_row, param_grad_col, ins_rank, ins_num, max_rank, input_col);
}

}  // namespace operators
}  // namespace paddle
@@ -0,0 +1,151 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/rank_attention_op.h"
#include <memory>
#include <string>
#include <vector>

namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
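// Shape contract: X is [ins_num, x_fea_dim], RankOffset is
// [ins_num, 2 * MaxRank + 1], RankParam is
// [MaxRank * MaxRank * x_fea_dim, para_col], and Out is [ins_num, para_col].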
class RankAttentionOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
                      platform::errors::InvalidArgument(
                          "Input(X) of RankAttentionOp should not be null."));
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("RankOffset"), true,
        platform::errors::InvalidArgument(
            "Input(RankOffset) of RankAttentionOp should not be null."));
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("RankParam"), true,
        platform::errors::InvalidArgument(
            "Input(RankParam) of RankAttentionOp should not be null."));
    PADDLE_ENFORCE_EQ(
        ctx->HasOutput("Out"), true,
        platform::errors::InvalidArgument(
            "Output(Out) of RankAttentionOp should not be null."));
    auto max_rank = ctx->Attrs().Get<int>("MaxRank");

    auto x_dims = ctx->GetInputDim("X");
    auto ins_num = x_dims[0];
    auto param_dims = ctx->GetInputDim("RankParam");
    auto para_col = param_dims[1];
    auto rank_offset_dims = ctx->GetInputDim("RankOffset");

    PADDLE_ENFORCE_EQ((rank_offset_dims[1] - 1) / 2, max_rank,
                      platform::errors::InvalidArgument(
                          "Input(RankOffset) has wrong columns."));

    ctx->SetOutputDim("Out", {ins_num, para_col});
    ctx->ShareLoD("X", /*->*/ "Out");
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
        ctx.device_context());
  }
};

class RankAttentionGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("X"), true,
        platform::errors::InvalidArgument("Input(X) should not be null"));
    PADDLE_ENFORCE_EQ(ctx->HasInput("RankParam"), true,
                      platform::errors::InvalidArgument(
                          "Input(RankParam) should not be null"));
    PADDLE_ENFORCE_EQ(ctx->HasInput("RankOffset"), true,
                      platform::errors::InvalidArgument(
                          "Input(RankOffset) should not be null"));

    ctx->SetOutputDim(framework::GradVarName("RankParam"),
                      ctx->GetInputDim("RankParam"));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
                                       ctx, framework::GradVarName("Out")),
                                   ctx.device_context());
  }
};
class RankAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor) Input tensor of the rank_attention operator.");
    AddInput("RankOffset",
             "(Tensor) Rank-offset tensor of the rank_attention operator.");
    AddInput("RankParam",
             "(Tensor) Rank-parameter tensor of the rank_attention operator.");
    AddOutput("Out", "Output tensor of the rank_attention operator.");
    AddAttr<int>("MaxRank",
                 "(int, default 3) max rank of the rank_attention operator")
        .SetDefault(3);
    AddComment(R"DOC(
RankAttention Operator.
This Op computes rank attention between Input(X) and Input(RankParam),
where RankParam describes how the data is organized by rank.
Note: it currently supports the GPU device only.
This Op lives in contrib, which means it is not exposed in the public API.
)DOC");
  }
};
template <typename T>
class RankAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("rank_attention_grad");

    op->SetInput("X", this->Input("X"));
    op->SetInput("RankOffset", this->Input("RankOffset"));
    op->SetInput("RankParam", this->Input("RankParam"));
    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));

    op->SetOutput(framework::GradVarName("RankParam"),
                  this->InputGrad("RankParam"));
    op->SetAttrMap(this->Attrs());
  }
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(
    RankAttentionGradOpNoNeedBufferVarsInference, "RankParam");

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(rank_attention, ops::RankAttentionOp,
                  ops::RankAttentionOpMaker,
                  ops::RankAttentionGradOpMaker<paddle::framework::OpDesc>,
                  ops::RankAttentionGradOpMaker<paddle::imperative::OpBase>);

REGISTER_OPERATOR(rank_attention_grad, ops::RankAttentionGradOp,
                  ops::RankAttentionGradOpNoNeedBufferVarsInference);

REGISTER_OP_CPU_KERNEL(
    rank_attention,
    ops::RankAttentionKernel<paddle::platform::CPUDeviceContext, float>,
    ops::RankAttentionKernel<paddle::platform::CPUDeviceContext, double>);
@@ -0,0 +1,215 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <cublas.h>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/rank_attention.cu.h"
#include "paddle/fluid/operators/rank_attention_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/gpu_info.h"

namespace paddle {
namespace operators {

using framework::Tensor;
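// Forward pass: expand X into per-instance [1, max_rank * x_fea_dim] rows and
// RankParam into per-instance [max_rank * x_fea_dim, para_col] blocks, then
// multiply the pairs with one batched GEMM to get Out ([ins_num, para_col]).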
template <typename DeviceContext, typename T>
class RankAttentionCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto *X = ctx.Input<Tensor>("X");
    auto *rank_offset = ctx.Input<Tensor>("RankOffset");
    auto *param = ctx.Input<Tensor>("RankParam");
    int max_rank = ctx.Attr<int>("MaxRank");
    auto *Out = ctx.Output<Tensor>("Out");

    // check dims
    auto x_dims = X->dims();
    auto ins_num = x_dims[0];
    auto x_fea_dim = x_dims[1];
    auto para_dims = param->dims();
    auto para_row = para_dims[0];
    auto para_col = para_dims[1];
    auto rank_offset_dims = rank_offset->dims();
    PADDLE_ENFORCE_EQ(
        rank_offset_dims[0], ins_num,
        platform::errors::InvalidArgument("Input(RankOffset) has wrong rows."));
    PADDLE_ENFORCE_EQ((rank_offset_dims[1] - 1) / 2, max_rank,
                      platform::errors::InvalidArgument(
                          "Input(RankOffset) has wrong columns."));
    PADDLE_ENFORCE_EQ(
        max_rank * max_rank * x_fea_dim, para_row,
        platform::errors::InvalidArgument("Input(RankParam) has wrong rows."));

    int block_matrix_row = max_rank * x_fea_dim;

    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    auto stream = ctx.cuda_device_context().stream();
    int device_id = platform::GetCurrentDeviceId();

    T *param_help_data;
    auto param_help_size = ins_num * block_matrix_row * para_col * sizeof(T);
    platform::RecordedCudaMalloc(reinterpret_cast<void **>(&param_help_data),
                                 param_help_size, device_id);
    platform::GpuMemsetAsync(param_help_data, 0, param_help_size, stream);

    T *input_help_data;
    auto input_help_size = ins_num * block_matrix_row * sizeof(T);
    platform::RecordedCudaMalloc(reinterpret_cast<void **>(&input_help_data),
                                 input_help_size, device_id);
    platform::GpuMemsetAsync(input_help_data, 0, input_help_size, stream);

    T *ins_rank_data;
    auto ins_rank_size = ins_num * sizeof(T);
    platform::RecordedCudaMalloc(reinterpret_cast<void **>(&ins_rank_data),
                                 ins_rank_size, device_id);
    platform::GpuMemsetAsync(ins_rank_data, -1, ins_rank_size, stream);

    Out->mutable_data<T>(ctx.GetPlace());

    // initialize
    auto out_eigen = framework::EigenVector<T>::Flatten(*Out);
    auto &place = *ctx.template device_context<platform::CUDADeviceContext>()
                       .eigen_device();
    out_eigen.device(place) = out_eigen.constant(static_cast<T>(0));
    // get data ptr
    T *out_data = Out->data<T>();
    expand_rank_attention_input(
        ctx.cuda_device_context().stream(), X->data<T>(), ins_num, x_fea_dim,
        input_help_data, ins_num, block_matrix_row, rank_offset->data<int>(),
        rank_offset_dims[0], rank_offset_dims[1], ins_rank_data, max_rank);

    expand_rank_attention_param(
        ctx.cuda_device_context().stream(), X->data<T>(), ins_num, x_fea_dim,
        rank_offset->data<int>(), rank_offset_dims[0], rank_offset_dims[1],
        param->data<T>(), para_row, para_col, param_help_data,
        ins_num * block_matrix_row, para_col, max_rank);

    CBLAS_TRANSPOSE transA = CblasNoTrans;
    CBLAS_TRANSPOSE transB = CblasNoTrans;

    T alpha = 1;
    T beta = 0;
    int64_t strideA = block_matrix_row;
    int64_t strideB = block_matrix_row * para_col;
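    // Batched GEMM over ins_num instances:
    // [1, block_matrix_row] x [block_matrix_row, para_col] -> [1, para_col].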
    auto blas = math::GetBlas<platform::CUDADeviceContext, T>(dev_ctx);
    blas.BatchedGEMM(transA, transB, 1, para_col, block_matrix_row, alpha,
                     input_help_data, param_help_data, beta, out_data, ins_num,
                     strideA, strideB);

    platform::RecordedCudaFree(param_help_data, param_help_size, device_id);
    platform::RecordedCudaFree(input_help_data, input_help_size, device_id);
    platform::RecordedCudaFree(ins_rank_data, ins_rank_size, device_id);
  }
};
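// Backward pass: re-expand X, compute the expanded parameter gradient with a
// batched GEMM (per instance, X_expand^T x dOut, an outer product), then
// merge it back onto RankParam@GRAD according to each instance's rank.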
template <typename DeviceContext, typename T>
class RankAttentionGradOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto *X = ctx.Input<Tensor>("X");
    auto *rank_offset = ctx.Input<Tensor>("RankOffset");
    auto *param = ctx.Input<Tensor>("RankParam");
    auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));

    auto *drank_para = ctx.Output<Tensor>(framework::GradVarName("RankParam"));

    // get dim
    auto x_dims = X->dims();
    auto ins_num = x_dims[0];
    auto x_fea_dim = x_dims[1];
    auto para_dims = param->dims();
    auto para_row = para_dims[0];
    auto para_col = para_dims[1];
    auto rank_offset_dims = rank_offset->dims();
    auto max_rank = (rank_offset_dims[1] - 1) / 2;
    int block_matrix_row = max_rank * x_fea_dim;
    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    auto &place = *ctx.template device_context<platform::CUDADeviceContext>()
                       .eigen_device();

    // initialize out grad
    drank_para->mutable_data<T>(ctx.GetPlace());
    auto drank_para_eigen = framework::EigenVector<T>::Flatten(*drank_para);
    drank_para_eigen.device(place) =
        drank_para_eigen.constant(static_cast<T>(0));

    auto stream = ctx.cuda_device_context().stream();
    int device_id = platform::GetCurrentDeviceId();

    T *param_grad_data;
    auto param_grad_size = ins_num * block_matrix_row * para_col * sizeof(T);
    platform::RecordedCudaMalloc(reinterpret_cast<void **>(&param_grad_data),
                                 param_grad_size, device_id);
    platform::GpuMemsetAsync(param_grad_data, 0, param_grad_size, stream);

    T *input_help_data;
    auto input_help_size = ins_num * block_matrix_row * sizeof(T);
    platform::RecordedCudaMalloc(reinterpret_cast<void **>(&input_help_data),
                                 input_help_size, device_id);
    platform::GpuMemsetAsync(input_help_data, 0, input_help_size, stream);

    T *ins_rank_data;
    auto ins_rank_size = ins_num * sizeof(T);
    platform::RecordedCudaMalloc(reinterpret_cast<void **>(&ins_rank_data),
                                 ins_rank_size, device_id);
    platform::GpuMemsetAsync(ins_rank_data, -1, ins_rank_size, stream);

    // expand input
    expand_rank_attention_input(
        ctx.cuda_device_context().stream(), X->data<T>(), ins_num, x_fea_dim,
        input_help_data, ins_num, block_matrix_row, rank_offset->data<int>(),
        rank_offset_dims[0], rank_offset_dims[1], ins_rank_data, max_rank);

    auto blas = math::GetBlas<platform::CUDADeviceContext, T>(dev_ctx);
    T alpha = 1;
    T beta = 0;

    // get param_grad
    CBLAS_TRANSPOSE transA = CblasTrans;
    CBLAS_TRANSPOSE transB = CblasNoTrans;
    int64_t strideA = block_matrix_row;
    int64_t strideB = para_col;

    blas.BatchedGEMM(transA, transB, block_matrix_row, para_col, 1, alpha,
                     input_help_data, dout->data<T>(), beta, param_grad_data,
                     ins_num, strideA, strideB);

    // merge param_grad to get drank_para
    merge_rank_attention_param_grad(
        ctx.cuda_device_context().stream(), param_grad_data,
        ins_num * block_matrix_row, para_col, drank_para->data<T>(), para_row,
        para_col, ins_rank_data, ins_num, max_rank, x_fea_dim);

    platform::RecordedCudaFree(param_grad_data, param_grad_size, device_id);
    platform::RecordedCudaFree(input_help_data, input_help_size, device_id);
    platform::RecordedCudaFree(ins_rank_data, ins_rank_size, device_id);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
using GPUCtx = paddle::platform::CUDADeviceContext;
REGISTER_OP_CUDA_KERNEL(rank_attention,
                        ops::RankAttentionCUDAKernel<GPUCtx, float>,
                        ops::RankAttentionCUDAKernel<GPUCtx, double>);

REGISTER_OP_CUDA_KERNEL(rank_attention_grad,
                        ops::RankAttentionGradOpCUDAKernel<GPUCtx, float>,
                        ops::RankAttentionGradOpCUDAKernel<GPUCtx, double>);
@@ -0,0 +1,32 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {
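// CPU placeholder: the op is GPU-only, so this kernel simply raises
// Unimplemented whenever it is dispatched to a CPU place.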
template <typename DeviceContext, typename T>
class RankAttentionKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
                      platform::errors::Unimplemented(
                          "Rank Attention only supports GPU now."));
  }
};
}  // namespace operators
}  // namespace paddle
@@ -0,0 +1,221 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np
import random
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
from op_test import OpTest, skip_check_grad_ci
import paddle.fluid.core as core

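# NumPy reference for expand_rank_attention_input: builds the
# [input_row, max_rank * input_col] expanded input and the per-instance
# ins_rank vector from rank_offset.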
def gen_input_help(input, rank_offset, max_rank):
    input_row, input_col = input.shape
    input_help = np.zeros((input_row * max_rank * input_col, ))
    ins_rank = np.zeros((input_row, 1))
    ins_rank.fill(-1)

    output_col = max_rank * input_col
    output_row = input_row

    for idx in range(output_col * output_row):
        output_col_idx = idx % output_col
        output_row_idx = int(idx / output_col)
        k = int(output_col_idx / input_col)
        faster = rank_offset[output_row_idx, 2 * k + 1] - 1

        if output_col_idx == 0:
            ins_rank[output_row_idx] = rank_offset[output_row_idx, 0]

        if rank_offset[output_row_idx, 0] - 1 < 0 or faster < 0:
            continue

        rank_input_col_idx = output_col_idx % input_col
        index = rank_offset[output_row_idx, 2 * k + 2]
        input_help[idx] = input[index, rank_input_col_idx]
    input_help = input_help.reshape([input_row, max_rank * input_col])

    return input_help, ins_rank

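# NumPy reference for expand_rank_attention_param: gathers, per instance, the
# parameter block addressed by its own rank and each attended rank k.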
def gen_param_help(input, rank_offset, param, max_rank):
    input_row, input_col = input.shape
    rank_offset_row, rank_offset_col = rank_offset.shape
    param_row, param_col = param.shape

    block_matrix_row = input_col * max_rank

    output_param_row = block_matrix_row * input_row
    output_param_col = param_col

    output_param = np.zeros((output_param_row * output_param_col, ))

    for idx in range(output_param_row * output_param_col):
        output_col_idx = idx % output_param_col
        output_row_idx = int(idx / output_param_col)
        ins_idx = int(output_row_idx / block_matrix_row)
        start_offset = output_row_idx % block_matrix_row
        k = int(start_offset / input_col)
        k_offset = start_offset % input_col

        lower = rank_offset[ins_idx, 0] - 1
        faster = rank_offset[ins_idx, 2 * k + 1] - 1
        if lower < 0 or faster < 0:
            continue
        start = lower * max_rank + faster
        ori_idx = start * param_col * input_col + k_offset * param_col + output_col_idx
        output_param[idx] = param[int(ori_idx / param_col), ori_idx % param_col]

    output_param = output_param.reshape([output_param_row, output_param_col])
    return output_param

def np_rank_attention(input, rank_offset, rank_para, max_rank):
    input_row, input_col = input.shape
    rank_offset_row, rank_offset_col = rank_offset.shape
    rank_para_row, rank_para_col = rank_para.shape

    assert input_row == rank_offset_row
    assert max_rank == ((rank_offset_col - 1) / 2)
    assert rank_para_row == max_rank * max_rank * input_col

    input_help, ins_rank = gen_input_help(input, rank_offset, max_rank)
    param_help = gen_param_help(input, rank_offset, rank_para, max_rank)
    block_matrix_row = input_col * max_rank

    res = np.zeros((input_row, rank_para_col))
    for ins in range(input_row):
        res[ins, :] = \
            np.dot(input_help[ins, :],
                   param_help[int(block_matrix_row * ins):int(block_matrix_row * (ins + 1)), :])
    return res, input_help, param_help, ins_rank

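# Builds a random RankOffset matrix: for each pv, column 0 holds the
# instance's own rank (or -1), and columns 2m+1 / 2m+2 hold the rank value
# and the global row index of the pv member whose rank is m+1.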
def gen_rank_offset(pv_nums, max_rank):
    all_ins_num = 0
    pv_rank_msg = []
    for _ in range(pv_nums):
        ins_pv = np.random.randint(1, max_rank + 2)  # 1 ~ max_rank + 1
        rank_list = list(range(1, ins_pv + 1))
        random.shuffle(rank_list)
        all_ins_num = all_ins_num + ins_pv
        pv_rank_msg.append(rank_list)

    rank_offset = np.zeros((all_ins_num, max_rank * 2 + 1)).astype("int32")
    rank_offset.fill(-1)
    index = 0
    for pv_number in range(len(pv_rank_msg)):
        pv_ins = pv_rank_msg[pv_number]
        ad_num = len(pv_ins)
        index_start = index

        for j in range(ad_num):
            rank = -1
            if pv_ins[j] <= max_rank:
                rank = pv_ins[j]
            rank_offset[index, 0] = rank

            if rank > 0:
                for k in range(ad_num):
                    fast_rank = -1
                    if pv_ins[k] <= max_rank:
                        fast_rank = pv_ins[k]
                    if fast_rank > 0:
                        m = fast_rank - 1
                        rank_offset[index, 2 * m + 1] = pv_ins[k]
                        rank_offset[index, 2 * m + 2] = index_start + k
            index = index + 1
    return all_ins_num, rank_offset

class TestRankAttentionOpComplex(OpTest):
    def config(self):
        self.pv_num = 100
        self.x_feat = 10
        self.y_feat = 15
        self.max_rank = 3
        self.dtype = "float64"

    def setUp(self):
        self.op_type = "rank_attention"
        self.config()
        ins_num, rank_offset = gen_rank_offset(self.pv_num, self.max_rank)
        input = np.random.random((ins_num, self.x_feat)).astype(self.dtype)
        rank_para_shape = [
            self.max_rank * self.max_rank * self.x_feat, self.y_feat
        ]
        rank_para = np.random.random(rank_para_shape).astype(self.dtype)
        np_out, np_input_help, np_param_help, np_ins_rank = np_rank_attention(
            input, np.array(rank_offset), rank_para, self.max_rank)
        self.inputs = {
            "X": input,
            "RankOffset": np.array(rank_offset).astype("int32"),
            "RankParam": rank_para
        }
        self.attrs = {'MaxRank': self.max_rank}
        self.outputs = {"Out": np_out}

    def test_check_output_gpu(self):
        if core.is_compiled_with_cuda():
            self.check_output_with_place(core.CUDAPlace(0))

    def test_check_grad_gpu(self):
        if core.is_compiled_with_cuda():
            self.check_grad_with_place(core.CUDAPlace(0), ["RankParam"], "Out")

class TestRankAttentionOpCpu(OpTest):
    def config(self):
        self.pv_num = 100
        self.x_feat = 10
        self.y_feat = 15
        self.max_rank = 3
        self.dtype = "float64"

    def setUp(self):
        self.op_type = "rank_attention"
        self.config()
        ins_num, rank_offset = gen_rank_offset(self.pv_num, self.max_rank)
        input = np.random.random((ins_num, self.x_feat)).astype(self.dtype)
        rank_para_shape = [
            self.max_rank * self.max_rank * self.x_feat, self.y_feat
        ]
        rank_para = np.random.random(rank_para_shape).astype(self.dtype)
        np_out, np_input_help, np_param_help, np_ins_rank = np_rank_attention(
            input, np.array(rank_offset), rank_para, self.max_rank)
        self.inputs = {
            "X": input,
            "RankOffset": np.array(rank_offset).astype("int32"),
            "RankParam": rank_para
        }
        self.attrs = {'MaxRank': self.max_rank}
        self.outputs = {"Out": np_out}

    def test_check_output_cpu(self):
        try:
            self.check_output_with_place(place=core.CPUPlace())
        except Exception:
            print("do not support cpu test, skip")

    def test_check_grad_cpu(self):
        try:
            self.check_grad_with_place(core.CPUPlace(), ["RankParam"], "Out")
        except Exception:
            print("do not support cpu test, skip")


if __name__ == "__main__":
    unittest.main()