@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/softmax_op.h"
+#ifdef PADDLE_WITH_CUDA
+#include "paddle/fluid/platform/cudnn_helper.h"
+#endif
 
 #ifdef PADDLE_WITH_MKLDNN
 #include "paddle/fluid/platform/mkldnn_helper.h"
@@ -41,29 +44,30 @@ class SoftmaxOp : public framework::OperatorWithKernel {
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     // choose cudnn kernel if the runtime supported.
-    bool use_cudnn = ctx.Attr<bool>("use_cudnn");
-    bool runtime_cudnn_support = false;
+    framework::LibraryType library_{framework::LibraryType::kPlain};
 #ifdef PADDLE_WITH_CUDA
-    if (platform::is_gpu_place(ctx.GetPlace())) {
-      auto& dev_ctx =
-          ctx.template device_context<platform::CUDADeviceContext>();
-      runtime_cudnn_support = dev_ctx.cudnn_handle() != nullptr ? true : false;
-    }
-#endif
-    framework::LibraryType library_ = framework::LibraryType::kPlain;
-    if (use_cudnn && runtime_cudnn_support) {
+    if (platform::CanCUDNNBeUsed(ctx)) {
       library_ = framework::LibraryType::kCUDNN;
     }
+#endif
 #ifdef PADDLE_WITH_MKLDNN
     if (library_ == framework::LibraryType::kPlain &&
         platform::CanMKLDNNBeUsed(ctx)) {
       library_ = framework::LibraryType::kMKLDNN;
     }
 #endif
+
+    auto input_data_type =
+        framework::ToDataType(ctx.Input<Tensor>("X")->type());
+    if (input_data_type == framework::proto::VarType::FP16) {
+      PADDLE_ENFORCE_EQ(library_, framework::LibraryType::kCUDNN,
+                        "float16 can only be used when CUDNN is used");
+    }
+
     std::string data_format = ctx.Attr<std::string>("data_format");
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
-        framework::StringToDataLayout(data_format), library_);
+    return framework::OpKernelType(input_data_type, ctx.GetPlace(),
+                                   framework::StringToDataLayout(data_format),
+                                   library_);
   }
 };
 class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -130,19 +134,12 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
     // choose cudnn kernel if the runtime supported.
-    bool use_cudnn = ctx.Attr<bool>("use_cudnn");
-    bool runtime_cudnn_support = false;
+    framework::LibraryType library_{framework::LibraryType::kPlain};
 #ifdef PADDLE_WITH_CUDA
-    if (platform::is_gpu_place(ctx.GetPlace())) {
-      auto& dev_ctx =
-          ctx.template device_context<platform::CUDADeviceContext>();
-      runtime_cudnn_support = dev_ctx.cudnn_handle() != nullptr ? true : false;
-    }
-#endif
-    framework::LibraryType library_ = framework::LibraryType::kPlain;
-    if (use_cudnn && runtime_cudnn_support) {
+    if (platform::CanCUDNNBeUsed(ctx)) {
       library_ = framework::LibraryType::kCUDNN;
     }
+#endif
     std::string data_format = ctx.Attr<std::string>("data_format");
     return framework::OpKernelType(
         framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
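
The hunks above amount to one change in how the softmax kernel is picked: the library is chosen once (cuDNN if platform::CanCUDNNBeUsed allows it, otherwise MKLDNN via platform::CanMKLDNNBeUsed, otherwise plain), and float16 input is then accepted only when the cuDNN kernel was chosen. The sketch below restates that post-patch flow as a small self-contained C++ program; Library, DataType, Context and ChooseSoftmaxKernel are hypothetical stand-ins for framework::LibraryType, proto::VarType, the ExecutionContext queries and GetExpectedKernelType, so this is an illustration of the logic, not Paddle code.

#include <cassert>

// Hypothetical stand-ins for framework::LibraryType and proto::VarType.
enum class Library { kPlain, kCUDNN, kMKLDNN };
enum class DataType { FP16, FP32 };

// Hypothetical stand-in for what GetExpectedKernelType reads off the
// ExecutionContext in the real operator.
struct Context {
  bool cudnn_usable;    // plays the role of platform::CanCUDNNBeUsed(ctx)
  bool mkldnn_usable;   // plays the role of platform::CanMKLDNNBeUsed(ctx)
  DataType input_type;  // element type of input "X"
};

// Mirrors the post-patch selection flow (simplified).
Library ChooseSoftmaxKernel(const Context& ctx) {
  Library library = Library::kPlain;
  if (ctx.cudnn_usable) {  // guarded by PADDLE_WITH_CUDA in the real code
    library = Library::kCUDNN;
  }
  if (library == Library::kPlain && ctx.mkldnn_usable) {
    library = Library::kMKLDNN;
  }
  // The new guard from the forward-op hunk: only the cuDNN kernel implements
  // float16, so any other choice is rejected (PADDLE_ENFORCE_EQ upstream).
  if (ctx.input_type == DataType::FP16) {
    assert(library == Library::kCUDNN &&
           "float16 can only be used when CUDNN is used");
  }
  return library;
}

int main() {
  assert(ChooseSoftmaxKernel({true, false, DataType::FP16}) == Library::kCUDNN);
  assert(ChooseSoftmaxKernel({false, true, DataType::FP32}) == Library::kMKLDNN);
  assert(ChooseSoftmaxKernel({false, false, DataType::FP32}) == Library::kPlain);
  return 0;
}

Note that only the forward op (second hunk) adds the float16 enforcement; the gradient op (third hunk) just switches to platform::CanCUDNNBeUsed and keeps its original return statement.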