@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/softmax_op.h"
+#ifdef PADDLE_WITH_CUDA
+#include "paddle/fluid/platform/cudnn_helper.h"
+#endif
 
 namespace paddle {
 namespace operators {
@@ -38,19 +41,12 @@ class SoftmaxOp : public framework::OperatorWithKernel {
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     // choose cudnn kernel if the runtime supported.
-    bool use_cudnn = ctx.Attr<bool>("use_cudnn");
-    bool runtime_cudnn_support = false;
+    framework::LibraryType library_{framework::LibraryType::kPlain};
 #ifdef PADDLE_WITH_CUDA
-    if (platform::is_gpu_place(ctx.GetPlace())) {
-      auto& dev_ctx =
-          ctx.template device_context<platform::CUDADeviceContext>();
-      runtime_cudnn_support = dev_ctx.cudnn_handle() != nullptr ? true : false;
+    if (platform::CanCUDNNBeUsed(ctx)) {
+      library_ = framework::LibraryType::kCUDNN;
     }
 #endif
-    framework::LibraryType library_ = framework::LibraryType::kPlain;
-    if (use_cudnn && runtime_cudnn_support) {
-      library_ = framework::LibraryType::kCUDNN;
-    }
     std::string data_format = ctx.Attr<std::string>("data_format");
     return framework::OpKernelType(
         framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
@@ -119,19 +115,12 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     // choose cudnn kernel if the runtime supported.
-    bool use_cudnn = ctx.Attr<bool>("use_cudnn");
-    bool runtime_cudnn_support = false;
+    framework::LibraryType library_{framework::LibraryType::kPlain};
 #ifdef PADDLE_WITH_CUDA
-    if (platform::is_gpu_place(ctx.GetPlace())) {
-      auto& dev_ctx =
-          ctx.template device_context<platform::CUDADeviceContext>();
-      runtime_cudnn_support = dev_ctx.cudnn_handle() != nullptr ? true : false;
+    if (platform::CanCUDNNBeUsed(ctx)) {
+      library_ = framework::LibraryType::kCUDNN;
     }
 #endif
-    framework::LibraryType library_ = framework::LibraryType::kPlain;
-    if (use_cudnn && runtime_cudnn_support) {
-      library_ = framework::LibraryType::kCUDNN;
-    }
     std::string data_format = ctx.Attr<std::string>("data_format");
     return framework::OpKernelType(
         framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
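
Note: both hunks replace the hand-rolled use_cudnn / runtime_cudnn_support bookkeeping with a single call to platform::CanCUDNNBeUsed(ctx), declared in the newly included paddle/fluid/platform/cudnn_helper.h. The diff does not show that helper; as a minimal sketch of what it presumably does, pieced together from the removed lines (an assumption, not the verbatim definition), it would combine the attribute, place, and handle checks like this:

// Sketch of the assumed helper; the real definition lives in
// paddle/fluid/platform/cudnn_helper.h and may differ in detail.
inline bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx) {
  // Honor the operator's "use_cudnn" attribute first.
  bool use_cudnn = ctx.Attr<bool>("use_cudnn");
  // cuDNN kernels only make sense on a GPU place.
  use_cudnn &= platform::is_gpu_place(ctx.GetPlace());
#ifdef PADDLE_WITH_CUDA
  if (use_cudnn) {
    // Finally, require a live cuDNN handle in the device context.
    auto& dev_ctx = ctx.device_context<platform::CUDADeviceContext>();
    use_cudnn &= dev_ctx.cudnn_handle() != nullptr;
  }
#endif
  return use_cudnn;
}

Centralizing the check this way removes the duplicated attribute-plus-runtime test from every operator's GetExpectedKernelType, which is why each hunk above shrinks from 19 lines to 12.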