|
|
|
@ -44,7 +44,7 @@ class SoftmaxOp : public framework::OperatorWithKernel {
|
|
|
|
|
framework::LibraryType library_{framework::LibraryType::kPlain};
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
if (platform::CanCUDNNBeUsed(ctx)) {
|
|
|
|
|
library = framework::LibraryType::kCUDNN;
|
|
|
|
|
library_ = framework::LibraryType::kCUDNN;
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
std::string data_format = ctx.Attr<std::string>("data_format");
|
|
|
|
@ -118,7 +118,7 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
|
|
|
|
|
framework::LibraryType library_{framework::LibraryType::kPlain};
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
if (platform::CanCUDNNBeUsed(ctx)) {
|
|
|
|
|
library = framework::LibraryType::kCUDNN;
|
|
|
|
|
library_ = framework::LibraryType::kCUDNN;
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
std::string data_format = ctx.Attr<std::string>("data_format");
|
|
|
|
|