You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
121 lines
4.3 KiB
121 lines
4.3 KiB
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License. */
|
|
|
|
#include "paddle/fluid/framework/op_registry.h"
|
|
#include "paddle/fluid/operators/softmax_op.h"
|
|
#include "paddle/fluid/platform/cudnn_helper.h"
|
|
|
|
// Forward declarations of the platform types that are only named (never
// defined) in this file: plat::CUDAPlace and plat::float16 appear in the
// kernel registrations at the bottom.
namespace paddle {
namespace platform {
struct float16;
struct CUDAPlace;
}  // namespace platform
}  // namespace paddle
|
|
|
|
namespace paddle {
|
|
namespace operators {
|
|
|
|
using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
|
|
using DataLayout = platform::DataLayout;
|
|
using Tensor = framework::Tensor;
|
|
|
|
static inline int SizeOutAxis(const int axis, DDim dims) {
|
|
int size = 1;
|
|
for (int i = axis + 1; i < dims.size(); i++) {
|
|
size *= dims[i];
|
|
}
|
|
return size;
|
|
}
|
|
|
|
template <typename T>
|
|
class SoftmaxCUDNNKernel : public framework::OpKernel<T> {
|
|
public:
|
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
|
auto* x = ctx.Input<Tensor>("X");
|
|
auto* out = ctx.Output<Tensor>("Out");
|
|
out->mutable_data<T>(ctx.GetPlace());
|
|
auto* out_data = out->data<T>();
|
|
|
|
auto dims = x->dims();
|
|
const int rank = dims.size();
|
|
const int axis = CanonicalAxis(ctx.Attr<int>("axis"), rank);
|
|
const int dim = dims[axis];
|
|
const int N = SizeToAxis(axis, dims);
|
|
const int D = SizeOutAxis(axis, dims);
|
|
|
|
ScopedTensorDescriptor desc;
|
|
std::vector<int> tensor_dims = {N, dim, D, 1};
|
|
DataLayout layout = DataLayout::kNCHW;
|
|
cudnnTensorDescriptor_t desc_ = desc.descriptor<T>(layout, tensor_dims);
|
|
|
|
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
|
|
auto handle = dev_ctx.cudnn_handle();
|
|
auto mode = axis == rank - 1 ? CUDNN_SOFTMAX_MODE_INSTANCE
|
|
: CUDNN_SOFTMAX_MODE_CHANNEL;
|
|
|
|
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSoftmaxForward(
|
|
handle, CUDNN_SOFTMAX_ACCURATE, mode,
|
|
platform::CudnnDataType<T>::kOne(), desc_, x->data<T>(),
|
|
platform::CudnnDataType<T>::kZero(), desc_, out_data));
|
|
}
|
|
};
|
|
|
|
template <typename T>
|
|
class SoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
|
|
public:
|
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
|
auto* out = ctx.Input<Tensor>("Out");
|
|
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
|
|
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
|
|
dx->mutable_data<T>(ctx.GetPlace());
|
|
auto* dx_data = dx->data<T>();
|
|
|
|
auto dims = out->dims();
|
|
const int rank = dims.size();
|
|
const int axis = CanonicalAxis(ctx.Attr<int>("axis"), rank);
|
|
const int dim = dims[axis];
|
|
const int N = SizeToAxis(axis, dims);
|
|
const int D = SizeOutAxis(axis, dims);
|
|
|
|
ScopedTensorDescriptor desc;
|
|
std::vector<int> tensor_dims = {N, dim, D, 1};
|
|
DataLayout layout = DataLayout::kNCHW;
|
|
cudnnTensorDescriptor_t desc_ = desc.descriptor<T>(layout, tensor_dims);
|
|
|
|
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
|
|
auto handle = dev_ctx.cudnn_handle();
|
|
auto mode = axis == rank - 1 ? CUDNN_SOFTMAX_MODE_INSTANCE
|
|
: CUDNN_SOFTMAX_MODE_CHANNEL;
|
|
|
|
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSoftmaxBackward(
|
|
handle, CUDNN_SOFTMAX_ACCURATE, mode,
|
|
platform::CudnnDataType<T>::kOne(), desc_, out->data<T>(), desc_,
|
|
dout->data<T>(), platform::CudnnDataType<T>::kZero(), desc_, dx_data));
|
|
}
|
|
};
|
|
|
|
} // namespace operators
|
|
} // namespace paddle
|
|
|
|
namespace ops = paddle::operators;
|
|
namespace plat = paddle::platform;
|
|
REGISTER_OP_KERNEL(softmax, CUDNN, plat::CUDAPlace,
|
|
ops::SoftmaxCUDNNKernel<float>,
|
|
ops::SoftmaxCUDNNKernel<double>,
|
|
ops::SoftmaxCUDNNKernel<plat::float16>);
|
|
REGISTER_OP_KERNEL(softmax_grad, CUDNN, plat::CUDAPlace,
|
|
ops::SoftmaxGradCUDNNKernel<float>,
|
|
ops::SoftmaxGradCUDNNKernel<double>,
|
|
ops::SoftmaxGradCUDNNKernel<plat::float16>);
|