|
|
@ -86,16 +86,16 @@ class TransposeOp : public framework::OperatorWithKernel {
|
|
|
|
framework::LibraryType library_{framework::LibraryType::kPlain};
|
|
|
|
framework::LibraryType library_{framework::LibraryType::kPlain};
|
|
|
|
std::string data_format = ctx.Attr<std::string>("data_format");
|
|
|
|
std::string data_format = ctx.Attr<std::string>("data_format");
|
|
|
|
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
|
|
|
|
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
|
|
|
|
|
|
|
|
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
|
|
|
|
#ifdef PADDLE_WITH_MKLDNN
|
|
|
|
#ifdef PADDLE_WITH_MKLDNN
|
|
|
|
if (library_ == framework::LibraryType::kPlain &&
|
|
|
|
if (library_ == framework::LibraryType::kPlain &&
|
|
|
|
this->CanMKLDNNBeUsed(ctx)) {
|
|
|
|
this->CanMKLDNNBeUsed(ctx, data_type)) {
|
|
|
|
library_ = framework::LibraryType::kMKLDNN;
|
|
|
|
library_ = framework::LibraryType::kMKLDNN;
|
|
|
|
layout_ = framework::DataLayout::kMKLDNN;
|
|
|
|
layout_ = framework::DataLayout::kMKLDNN;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
#endif
|
|
|
|
#endif
|
|
|
|
return framework::OpKernelType(
|
|
|
|
return framework::OpKernelType(data_type, ctx.GetPlace(), layout_,
|
|
|
|
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
|
|
|
|
library_);
|
|
|
|
layout_, library_);
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
@ -184,16 +184,17 @@ class TransposeOpGrad : public framework::OperatorWithKernel {
|
|
|
|
framework::LibraryType library_{framework::LibraryType::kPlain};
|
|
|
|
framework::LibraryType library_{framework::LibraryType::kPlain};
|
|
|
|
std::string data_format = ctx.Attr<std::string>("data_format");
|
|
|
|
std::string data_format = ctx.Attr<std::string>("data_format");
|
|
|
|
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
|
|
|
|
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
|
|
|
|
|
|
|
|
auto data_type = OperatorWithKernel::IndicateVarDataType(
|
|
|
|
|
|
|
|
ctx, framework::GradVarName("Out"));
|
|
|
|
#ifdef PADDLE_WITH_MKLDNN
|
|
|
|
#ifdef PADDLE_WITH_MKLDNN
|
|
|
|
if (library_ == framework::LibraryType::kPlain &&
|
|
|
|
if (library_ == framework::LibraryType::kPlain &&
|
|
|
|
this->CanMKLDNNBeUsed(ctx)) {
|
|
|
|
this->CanMKLDNNBeUsed(ctx, data_type)) {
|
|
|
|
library_ = framework::LibraryType::kMKLDNN;
|
|
|
|
library_ = framework::LibraryType::kMKLDNN;
|
|
|
|
layout_ = framework::DataLayout::kMKLDNN;
|
|
|
|
layout_ = framework::DataLayout::kMKLDNN;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
#endif
|
|
|
|
#endif
|
|
|
|
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
|
|
|
|
return framework::OpKernelType(data_type, ctx.GetPlace(), layout_,
|
|
|
|
ctx, framework::GradVarName("Out")),
|
|
|
|
library_);
|
|
|
|
ctx.GetPlace(), layout_, library_);
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
@ -231,9 +232,11 @@ class Transpose2Op : public TransposeOp {
|
|
|
|
int customized_type_value =
|
|
|
|
int customized_type_value =
|
|
|
|
framework::OpKernelType::kDefaultCustomizedTypeValue;
|
|
|
|
framework::OpKernelType::kDefaultCustomizedTypeValue;
|
|
|
|
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
|
|
|
|
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
|
|
|
|
|
|
|
|
framework::proto::VarType::Type data_type =
|
|
|
|
|
|
|
|
OperatorWithKernel::IndicateVarDataType(ctx, "X");
|
|
|
|
#ifdef PADDLE_WITH_MKLDNN
|
|
|
|
#ifdef PADDLE_WITH_MKLDNN
|
|
|
|
if (library_ == framework::LibraryType::kPlain &&
|
|
|
|
if (library_ == framework::LibraryType::kPlain &&
|
|
|
|
this->CanMKLDNNBeUsed(ctx)) {
|
|
|
|
this->CanMKLDNNBeUsed(ctx, data_type)) {
|
|
|
|
library_ = framework::LibraryType::kMKLDNN;
|
|
|
|
library_ = framework::LibraryType::kMKLDNN;
|
|
|
|
layout_ = framework::DataLayout::kMKLDNN;
|
|
|
|
layout_ = framework::DataLayout::kMKLDNN;
|
|
|
|
using framework::proto::VarType;
|
|
|
|
using framework::proto::VarType;
|
|
|
@ -244,9 +247,8 @@ class Transpose2Op : public TransposeOp {
|
|
|
|
: kTransposeMKLDNNFP32;
|
|
|
|
: kTransposeMKLDNNFP32;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
#endif
|
|
|
|
#endif
|
|
|
|
return framework::OpKernelType(
|
|
|
|
return framework::OpKernelType(data_type, ctx.GetPlace(), layout_, library_,
|
|
|
|
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
|
|
|
|
customized_type_value);
|
|
|
|
layout_, library_, customized_type_value);
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
@ -310,16 +312,18 @@ class Transpose2OpGrad : public framework::OperatorWithKernel {
|
|
|
|
framework::LibraryType library_{framework::LibraryType::kPlain};
|
|
|
|
framework::LibraryType library_{framework::LibraryType::kPlain};
|
|
|
|
std::string data_format = ctx.Attr<std::string>("data_format");
|
|
|
|
std::string data_format = ctx.Attr<std::string>("data_format");
|
|
|
|
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
|
|
|
|
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
|
|
|
|
|
|
|
|
framework::proto::VarType::Type data_type =
|
|
|
|
|
|
|
|
OperatorWithKernel::IndicateVarDataType(ctx,
|
|
|
|
|
|
|
|
framework::GradVarName("Out"));
|
|
|
|
#ifdef PADDLE_WITH_MKLDNN
|
|
|
|
#ifdef PADDLE_WITH_MKLDNN
|
|
|
|
if (library_ == framework::LibraryType::kPlain &&
|
|
|
|
if (library_ == framework::LibraryType::kPlain &&
|
|
|
|
this->CanMKLDNNBeUsed(ctx)) {
|
|
|
|
this->CanMKLDNNBeUsed(ctx, data_type)) {
|
|
|
|
library_ = framework::LibraryType::kMKLDNN;
|
|
|
|
library_ = framework::LibraryType::kMKLDNN;
|
|
|
|
layout_ = framework::DataLayout::kMKLDNN;
|
|
|
|
layout_ = framework::DataLayout::kMKLDNN;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
#endif
|
|
|
|
#endif
|
|
|
|
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
|
|
|
|
return framework::OpKernelType(data_type, ctx.GetPlace(), layout_,
|
|
|
|
ctx, framework::GradVarName("Out")),
|
|
|
|
library_);
|
|
|
|
ctx.GetPlace(), layout_, library_);
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|