|
|
|
@ -68,29 +68,6 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
class TransposeINT8MKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
|
|
|
|
|
std::vector<int> axis = ctx.Attr<std::vector<int>>("axis");
|
|
|
|
|
std::vector<int> axis_int8 = {0, 2, 3, 1};
|
|
|
|
|
if (axis.size() != 1) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(axis.size(), axis_int8.size());
|
|
|
|
|
for (size_t i = 0; i < axis.size(); i++) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(axis[i], axis_int8[i],
|
|
|
|
|
"Current INT8 MKLDNN Transpose kernel only surpport "
|
|
|
|
|
"axis with [0, 2, 3, 1] due to MKL-DNN kernel "
|
|
|
|
|
"implementation.");
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
auto* input = ctx.Input<Tensor>("X");
|
|
|
|
|
auto* output = ctx.Output<Tensor>("Out");
|
|
|
|
|
output->ShareDataWith(*input);
|
|
|
|
|
output->set_layout(DataLayout::kMKLDNN);
|
|
|
|
|
output->set_format(input->format());
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
@ -148,9 +125,7 @@ class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
// Short alias for the operator namespace, used by the kernel registrations
// below.
namespace ops = paddle::operators;
|
|
|
|
|
|
|
|
|
|
// Register the MKL-DNN CPU kernels for the `transpose2` op: the float kernel
// plus the INT8 kernel instantiated for both unsigned and signed 8-bit data.
REGISTER_OP_KERNEL(transpose2, MKLDNN, ::paddle::platform::CPUPlace,
                   ops::TransposeMKLDNNOpKernel<float>,
                   ops::TransposeINT8MKLDNNOpKernel<uint8_t>,
                   ops::TransposeINT8MKLDNNOpKernel<int8_t>);
|
|
|
|
|
|
|
|
|
|
// Register the MKL-DNN CPU kernel for the `transpose` op (float only).
REGISTER_OP_KERNEL(
    transpose, MKLDNN, ::paddle::platform::CPUPlace,
    ops::TransposeMKLDNNOpKernel<float>);
|
|
|
|
|