@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cstdio>
#include "cub/cub.cuh"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/top_k_op.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/float16.h"
// set cub base traits in order to handle float16
namespace cub {
template <>
@ -300,6 +300,20 @@ __global__ void KeMatrixTopK(T* output, int output_stride, int64_t* indices,
}
}
// Backward scatter for top_k: zero-fills x_grad and copies each out_grad
// element back to the column recorded in `indices`.
//
// Launch layout (see TopkOpGradCUDAKernel): one block per row, grid-striding
// when rows > gridDim.x; threads within a block cooperate across columns.
// `indices` is (rows, k) int64, `out_grad` is (rows, k), `x_grad` is
// (rows, cols). MaxLength/BlockSize are kept for signature compatibility
// with the FIXED_BLOCK_DIM dispatch macro; BlockSize mirrors blockDim.x.
//
// NOTE: the previous implementation was fully serial (it never consulted
// threadIdx/blockIdx) while being launched with many blocks and threads, so
// every thread redundantly wrote the whole matrix; a lagging thread's
// zero-fill could overwrite another thread's already-scattered gradient,
// yielding nondeterministic results. This version parallelizes the work and
// separates the two phases with a barrier.
template <typename T, int MaxLength, int BlockSize>
__global__ void AssignGrad(T* x_grad, const int64_t* indices, const T* out_grad,
                           size_t rows, size_t cols, size_t k) {
  for (size_t i = blockIdx.x; i < rows; i += gridDim.x) {
    // Phase 1: zero the entire row, threads striding over columns.
    for (size_t j = threadIdx.x; j < cols; j += blockDim.x) {
      x_grad[i * cols + j] = static_cast<T>(0);
    }
    // All zeros for this row must land before any scatter write below,
    // otherwise a zero could clobber a gradient written by a sibling thread.
    __syncthreads();
    // Phase 2: scatter the k gradients into the selected columns.
    for (size_t j = threadIdx.x; j < k; j += blockDim.x) {
      size_t idx = static_cast<size_t>(indices[i * k + j]);
      x_grad[i * cols + idx] = out_grad[i * k + j];
    }
    // Barrier before the next grid-stride row so phase 1 of the next
    // iteration cannot race with phase 2 of this one.
    __syncthreads();
  }
}
inline static int GetDesiredBlockDim(int dim) {
if (dim > 128) {
return 256;
@ -478,7 +492,7 @@ bool SortTopk(const platform::CUDADeviceContext& ctx,
FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__)
template <typename T>
template <typename DeviceContext, typename T>
class TopkOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
@ -540,6 +554,42 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
}
};
// GPU backward kernel for the top_k op: dX is zero everywhere except at the
// positions listed in "Indices", which receive the corresponding dOut values.
template <typename DeviceContext, typename T>
class TopkOpGradCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    // This kernel is registered for CUDA only; fail fast on a wrong place.
    PADDLE_ENFORCE_EQ(
        platform::is_gpu_place(context.GetPlace()), true,
        platform::errors::InvalidArgument("It must use CUDAPlace."));
    auto* x = context.Input<Tensor>("X");
    auto* out_grad = context.Input<Tensor>(framework::GradVarName("Out"));
    auto* indices = context.Input<Tensor>("Indices");
    auto* x_grad = context.Output<Tensor>(framework::GradVarName("X"));
    T* x_grad_data = x_grad->mutable_data<T>(context.GetPlace());
    const T* out_grad_data = out_grad->data<T>();
    const int64_t* indices_data = indices->data<int64_t>();
    // k = size of the last dimension of "Indices" (the number of kept
    // elements per row chosen in the forward pass).
    size_t k = indices->dims()[indices->dims().size() - 1];

    // Flatten X to a (row, col) matrix: all leading dims collapse into rows,
    // top-k was taken along the last dim (col).
    framework::DDim xdims = x->dims();
    const size_t row =
        framework::product(framework::slice_ddim(xdims, 0, xdims.size() - 1));
    const size_t col = xdims[xdims.size() - 1];
    const auto& dev_ctx = context.cuda_device_context();

    // Cap the grid height; AssignGrad is expected to cover the remaining
    // rows itself when row > kMaxHeight.
    const int kMaxHeight = 2048;
    int gridx = row < kMaxHeight ? row : kMaxHeight;
    // FIXED_BLOCK_DIM expands into `case <dim>:` labels that bind kBlockDim
    // to a compile-time block size matching GetDesiredBlockDim(col).
    // NOTE(review): the "5" template argument (MaxLength) appears unused by
    // the launch configuration — presumably a fixed placeholder; confirm
    // against AssignGrad's definition.
    switch (GetDesiredBlockDim(col)) {
      FIXED_BLOCK_DIM(
          AssignGrad<T, 5,
                     kBlockDim><<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(
              x_grad_data, indices_data, out_grad_data, row, col, k));
      default:
        PADDLE_THROW(
            platform::errors::Unavailable("Error occurs when Assign Grad."));
    }
  }
};
#undef FIXED_BLOCK_DIM_BASE
#undef FIXED_BLOCK_DIM
@ -547,8 +597,27 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
} // namespace paddle
// Register the forward top_k GPU kernels for every supported element type.
// (The single-template-parameter TopkOpCUDAKernel<T> spellings that used to
// appear here were stale leftovers from before the kernel gained its
// DeviceContext parameter; only the two-parameter form matches the class
// declaration and compiles.)
REGISTER_OP_CUDA_KERNEL(
    top_k,
    paddle::operators::TopkOpCUDAKernel<paddle::platform::CUDADeviceContext,
                                        float>,
    paddle::operators::TopkOpCUDAKernel<paddle::platform::CUDADeviceContext,
                                        double>,
    paddle::operators::TopkOpCUDAKernel<paddle::platform::CUDADeviceContext,
                                        int>,
    paddle::operators::TopkOpCUDAKernel<paddle::platform::CUDADeviceContext,
                                        int64_t>,
    paddle::operators::TopkOpCUDAKernel<paddle::platform::CUDADeviceContext,
                                        paddle::platform::float16>);
// Register the backward (top_k_grad) GPU kernels for the same element types
// as the forward op: float, double, int, int64_t, and float16.
REGISTER_OP_CUDA_KERNEL(
    top_k_grad,
    paddle::operators::TopkOpGradCUDAKernel<paddle::platform::CUDADeviceContext,
                                            float>,
    paddle::operators::TopkOpGradCUDAKernel<paddle::platform::CUDADeviceContext,
                                            double>,
    paddle::operators::TopkOpGradCUDAKernel<paddle::platform::CUDADeviceContext,
                                            int>,
    paddle::operators::TopkOpGradCUDAKernel<paddle::platform::CUDADeviceContext,
                                            int64_t>,
    paddle::operators::TopkOpGradCUDAKernel<paddle::platform::CUDADeviceContext,
                                            paddle::platform::float16>);