|
|
|
@ -13,6 +13,12 @@ See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/operators/conv_op.h"
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
#include "paddle/fluid/platform/cudnn_helper.h"
|
|
|
|
|
#endif
|
|
|
|
|
#ifdef PADDLE_WITH_MKLDNN
|
|
|
|
|
#include "paddle/fluid/platform/mkldnn_helper.h"
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
@ -64,22 +70,21 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
|
|
|
|
|
|
|
|
|
|
framework::OpKernelType ConvOp::GetExpectedKernelType(
|
|
|
|
|
const framework::ExecutionContext& ctx) const {
|
|
|
|
|
bool use_cudnn = ctx.Attr<bool>("use_cudnn");
|
|
|
|
|
use_cudnn &= platform::is_gpu_place(ctx.GetPlace());
|
|
|
|
|
framework::LibraryType library_{framework::LibraryType::kPlain};
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
if (platform::is_gpu_place(ctx.GetPlace())) {
|
|
|
|
|
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
|
|
|
|
|
use_cudnn &= dev_ctx.cudnn_handle() != nullptr;
|
|
|
|
|
if (platform::CanCUDNNBeUsed(ctx)) {
|
|
|
|
|
library_ = framework::LibraryType::kCUDNN;
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
framework::LibraryType library_;
|
|
|
|
|
if (use_cudnn) {
|
|
|
|
|
library_ = framework::LibraryType::kCUDNN;
|
|
|
|
|
} else {
|
|
|
|
|
library_ = framework::LibraryType::kPlain;
|
|
|
|
|
#ifdef PADDLE_WITH_MKLDNN
|
|
|
|
|
if (library_ == framework::LibraryType::kPlain &&
|
|
|
|
|
platform::CanMKLDNNBeUsed(ctx)) {
|
|
|
|
|
library_ = framework::LibraryType::kMKLDNN;
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
std::string data_format = ctx.Attr<std::string>("data_format");
|
|
|
|
|
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
|
|
|
|
|
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
|
|
|
|
|
return framework::OpKernelType(
|
|
|
|
|
framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(),
|
|
|
|
@ -131,6 +136,9 @@ Conv2DOpMaker::Conv2DOpMaker(OpProto* proto, OpAttrChecker* op_checker)
|
|
|
|
|
"use_cudnn",
|
|
|
|
|
"(bool, default false) Only used in cudnn kernel, need install cudnn")
|
|
|
|
|
.SetDefault(false);
|
|
|
|
|
AddAttr<bool>("use_mkldnn",
|
|
|
|
|
"(bool, default false) Only used in mkldnn kernel")
|
|
|
|
|
.SetDefault(false);
|
|
|
|
|
AddAttr<std::string>(
|
|
|
|
|
"data_format",
|
|
|
|
|
"(string, default NCHW) Only used in "
|
|
|
|
@ -224,6 +232,9 @@ Conv3DOpMaker::Conv3DOpMaker(OpProto* proto, OpAttrChecker* op_checker)
|
|
|
|
|
"use_cudnn",
|
|
|
|
|
"(bool, default false) Only used in cudnn kernel, need install cudnn")
|
|
|
|
|
.SetDefault(false);
|
|
|
|
|
AddAttr<bool>("use_mkldnn",
|
|
|
|
|
"(bool, default false) Only used in mkldnn kernel")
|
|
|
|
|
.SetDefault(false);
|
|
|
|
|
AddAttr<std::string>(
|
|
|
|
|
"data_format",
|
|
|
|
|
"(string, default NCHW) Only used in "
|
|
|
|
@ -284,23 +295,21 @@ void ConvOpGrad::InferShape(framework::InferShapeContext* ctx) const {
|
|
|
|
|
|
|
|
|
|
framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
|
|
|
|
|
const framework::ExecutionContext& ctx) const {
|
|
|
|
|
bool use_cudnn = ctx.Attr<bool>("use_cudnn");
|
|
|
|
|
use_cudnn &= platform::is_gpu_place(ctx.GetPlace());
|
|
|
|
|
framework::LibraryType library_{framework::LibraryType::kPlain};
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
if (platform::is_gpu_place(ctx.GetPlace())) {
|
|
|
|
|
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
|
|
|
|
|
use_cudnn &= dev_ctx.cudnn_handle() != nullptr;
|
|
|
|
|
if (platform::CanCUDNNBeUsed(ctx)) {
|
|
|
|
|
library_ = framework::LibraryType::kCUDNN;
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
framework::LibraryType library_;
|
|
|
|
|
if (use_cudnn) {
|
|
|
|
|
library_ = framework::LibraryType::kCUDNN;
|
|
|
|
|
} else {
|
|
|
|
|
library_ = framework::LibraryType::kPlain;
|
|
|
|
|
#ifdef PADDLE_WITH_MKLDNN
|
|
|
|
|
if (library_ == framework::LibraryType::kPlain &&
|
|
|
|
|
platform::CanMKLDNNBeUsed(ctx)) {
|
|
|
|
|
library_ = framework::LibraryType::kMKLDNN;
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
std::string data_format = ctx.Attr<std::string>("data_format");
|
|
|
|
|
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
|
|
|
|
|
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
|
|
|
|
|
return framework::OpKernelType(
|
|
|
|
|
framework::ToDataType(ctx.Input<Tensor>("Input")->type()), ctx.GetPlace(),
|
|
|
|
|