From 68c6160e639be38c57a7dd831f7b841b33e92676 Mon Sep 17 00:00:00 2001 From: Adam <38704900+grygielski@users.noreply.github.com> Date: Thu, 6 Aug 2020 09:57:27 +0200 Subject: [PATCH] Add oneDNN fusion_gru kernel (#25594) * Add oneDNN fusion_gru kernel and fix fc+gru pass test=develop * Formatting changes test=develop * Lint fixes test=develop * Add memory::format_tag::any to GRU weights test=develop * Fix build with CUDA * Fix build with CUDA v2 --- cmake/operators.cmake | 2 +- paddle/fluid/operators/fused/CMakeLists.txt | 7 +- paddle/fluid/operators/fused/fusion_gru_op.cc | 17 +- .../fused/mkldnn/fusion_gru_mkldnn_op.cc | 439 ++++++++++++++++++ paddle/fluid/platform/mkldnn_helper.h | 4 + .../mkldnn/test_fusion_gru_mkldnn_op.py | 78 ++++ .../tests/unittests/test_fusion_gru_op.py | 12 +- 7 files changed, 553 insertions(+), 6 deletions(-) create mode 100644 paddle/fluid/operators/fused/mkldnn/fusion_gru_mkldnn_op.cc create mode 100644 python/paddle/fluid/tests/unittests/mkldnn/test_fusion_gru_mkldnn_op.py diff --git a/cmake/operators.cmake b/cmake/operators.cmake index 5b03cbf8c7..ecf2dbc817 100644 --- a/cmake/operators.cmake +++ b/cmake/operators.cmake @@ -118,7 +118,7 @@ function(op_library TARGET) "tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op" "fusion_transpose_flatten_concat_op" "fusion_conv_inception_op" "sync_batch_norm_op" "dgc_op" "fused_fc_elementwise_layernorm_op" -"multihead_matmul_op" "fusion_group_op" "fused_bn_activation_op" "fused_embedding_eltwise_layernorm_op") +"multihead_matmul_op" "fusion_group_op" "fused_bn_activation_op" "fused_embedding_eltwise_layernorm_op" "fusion_gru_op") if ("${TARGET}" STREQUAL "${manual_pybind_op}") set(pybind_flag 1) endif() diff --git a/paddle/fluid/operators/fused/CMakeLists.txt b/paddle/fluid/operators/fused/CMakeLists.txt index 24f656140f..3fc5f3bfc6 100644 --- a/paddle/fluid/operators/fused/CMakeLists.txt +++ b/paddle/fluid/operators/fused/CMakeLists.txt @@ -7,7 +7,12 @@ register_operators(EXCLUDES fused_fc_elementwise_layernorm_op multihead_matmul_op fused_embedding_eltwise_layernorm_op - fusion_group_op) + fusion_group_op + fusion_gru_op) + +# fusion_gru_op does not have CUDA kernel +op_library(fusion_gru_op) +file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(fusion_gru);\n") if (WITH_GPU) # fused_bn_activation_op needs cudnn 7.4.1 above diff --git a/paddle/fluid/operators/fused/fusion_gru_op.cc b/paddle/fluid/operators/fused/fusion_gru_op.cc index f6c8316e2e..d0920098f6 100644 --- a/paddle/fluid/operators/fused/fusion_gru_op.cc +++ b/paddle/fluid/operators/fused/fusion_gru_op.cc @@ -19,6 +19,9 @@ limitations under the License. */ #include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/fc.h" #include "paddle/fluid/operators/math/sequence2batch.h" +#ifdef PADDLE_WITH_MKLDNN +#include "paddle/fluid/platform/mkldnn_helper.h" +#endif namespace paddle { namespace operators { @@ -122,8 +125,17 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const { framework::OpKernelType FusionGRUOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { + framework::LibraryType library = framework::LibraryType::kPlain; + framework::DataLayout layout = framework::DataLayout::kAnyLayout; +#ifdef PADDLE_WITH_MKLDNN + if (platform::CanMKLDNNBeUsed(ctx)) { + library = framework::LibraryType::kMKLDNN; + layout = framework::DataLayout::kMKLDNN; + } +#endif return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.device_context()); + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), layout, + library); } void FusionGRUOpMaker::Make() { @@ -187,6 +199,9 @@ void FusionGRUOpMaker::Make() { "bool" "use origin mode in article https://arxiv.org/abs/1412.3555") .SetDefault(false); + AddAttr("use_mkldnn", + "(bool, default false) Only used in mkldnn kernel") + .SetDefault(false); AddComment(R"DOC( The Fusion complete GRU Operator. This operator fuse the fully-connected operator into GRU, diff --git a/paddle/fluid/operators/fused/mkldnn/fusion_gru_mkldnn_op.cc b/paddle/fluid/operators/fused/mkldnn/fusion_gru_mkldnn_op.cc new file mode 100644 index 0000000000..3940aae53b --- /dev/null +++ b/paddle/fluid/operators/fused/mkldnn/fusion_gru_mkldnn_op.cc @@ -0,0 +1,439 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/fused/fusion_gru_op.h" +#include "paddle/fluid/platform/mkldnn_reuse.h" + +namespace paddle { +namespace operators { + +using paddle::framework::LoDTensor; +using paddle::framework::Tensor; +using paddle::platform::CPUDeviceContext; +using paddle::platform::MKLDNNGetDataType; +using paddle::platform::MKLDNNMemDesc; +using platform::to_void_cast; + +template +class GRUMKLDNNHandler : public platform::MKLDNNHandlerT { + public: + GRUMKLDNNHandler(const paddle::framework::ExecutionContext& ctx, + const platform::MKLDNNDeviceContext& dev_ctx, + const mkldnn::engine mkldnn_engine, + platform::Place cpu_place, const LoDTensor* input, + const Tensor* weight_h, const Tensor* h0, + const bool is_reverse, const int64_t N, const int64_t Ti, + const int64_t IC, const int64_t OC, + const std::string& unique_name) + : platform::MKLDNNHandlerT( + dev_ctx, dev_ctx.GetEngine(), cpu_place, + platform::CreateKey(unique_name, Ti)), + N(N), + Ti(Ti), + IC(IC), + OC(OC) { + // Create memory key without Ti because weights, bias and h0 memories + // do not depend on Ti size but primitive and input/output memory do + if (platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id() != + platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default) { + memory_key_ = unique_name; + } else { + memory_key_ = unique_name + "-t:" + platform::ThreadIDasStr(); + } + + if (!this->isCached()) { + // oneDNN kernel has hardcoded activation functions + PADDLE_ENFORCE_EQ( + ctx.Attr("gate_activation"), "sigmoid", + platform::errors::Unimplemented( + "oneDNN fusion_gru supports only sigmoid as a gate activation.")); + PADDLE_ENFORCE_EQ( + ctx.Attr("activation"), "tanh", + platform::errors::Unimplemented( + "oneDNN fusion_gru supports only tanh as an activation.")); + + // oneDNN RNN dimensions + const int64_t D = 1; // Directions + const int64_t L = 1; // Layers (PP supports only 1 stacked layer) + const int64_t G = 3; // Number of Gates, 3 for GRU + + // Create memory descriptors + auto input_md = MKLDNNMemDesc({Ti, N, IC}, MKLDNNGetDataType(), + MKLDNNMemoryFormat::any); + auto weight_x_md = MKLDNNMemDesc( + {L, D, IC, G, OC}, MKLDNNGetDataType(), MKLDNNMemoryFormat::any); + auto weight_h_md = MKLDNNMemDesc( + {L, D, OC, G, OC}, MKLDNNGetDataType(), MKLDNNMemoryFormat::any); + auto bias_md = MKLDNNMemDesc({L, D, G, OC}, MKLDNNGetDataType(), + MKLDNNMemoryFormat::ldgo); + auto hidden_md = MKLDNNMemDesc({Ti, N, OC}, MKLDNNGetDataType(), + MKLDNNMemoryFormat::any); + auto h0_md = dnnl::memory::desc(); + if (h0) { + h0_md = MKLDNNMemDesc({L, D, N, OC}, MKLDNNGetDataType(), + MKLDNNMemoryFormat::ldnc); + } + + // Create GRU oneDNN primitive + const auto direction = + is_reverse ? dnnl::rnn_direction::unidirectional_right2left + : dnnl::rnn_direction::unidirectional_left2right; + + this->AcquireForwardPrimitiveDescriptor( + dnnl::prop_kind::forward_inference, direction, input_md, h0_md, + weight_x_md, weight_h_md, bias_md, hidden_md, dnnl::memory::desc()); + } + } + + bool is_NTC() { + return (platform::GetMKLDNNFormat(this->fwd_pd_->dst_desc()) == + dnnl::memory::format_tag::ntc); + } + + void reorderRNNdata(const T* input_data, T* output_data, + std::vector lod, const bool is_reverse, + platform::RNNReorderType reorder_type) { + switch (reorder_type) { + // Reorder input memory [WORDS, C] + LoD -> [N, T, C] + case platform::RNNReorderType::PP_NTC: { + auto* input_data_iter = input_data; + for (int n = 0; n < N; ++n) { + const auto num_elements = (lod[n + 1] - lod[n]) * IC; + const auto offset = is_reverse ? (Ti * IC - num_elements) : 0; + memcpy(output_data + n * Ti * IC + offset, input_data_iter, + sizeof(T) * num_elements); + input_data_iter += num_elements; + } + } break; + // Reorder input memory [WORDS, C] + LoD -> [T, N, C] + case platform::RNNReorderType::PP_TNC: { + auto* input_data_iter = input_data; + for (int n = 0; n < N; ++n) { + const auto num_elements = (lod[n + 1] - lod[n]); + const auto offset = is_reverse ? (Ti - num_elements) : 0; + for (size_t t = 0; t < num_elements; ++t) { + memcpy(output_data + (t + offset) * N * IC + n * IC, + input_data_iter, sizeof(T) * IC); + input_data_iter += IC; + } + } + } break; + // Reorder output values to PP format [N, T, C] -> [WORDS, C] + case platform::RNNReorderType::NTC_PP: { + auto* output_data_iter = output_data; + for (int n = 0; n < N; ++n) { + const auto num_elements = (lod[n + 1] - lod[n]) * OC; + const auto offset = is_reverse ? (Ti * OC - num_elements) : 0; + memcpy(output_data_iter, input_data + n * Ti * OC + offset, + sizeof(T) * num_elements); + output_data_iter += num_elements; + } + } break; + // Reorder output values to PP format [T, N, C] -> [WORDS, C] + case platform::RNNReorderType::TNC_PP: { + auto* output_data_iter = output_data; + for (int n = 0; n < N; ++n) { + const auto num_elements = lod[n + 1] - lod[n]; + const auto offset = is_reverse ? (Ti - num_elements) : 0; + for (size_t t = 0; t < num_elements; ++t) { + memcpy(output_data_iter, + input_data + (t + offset) * N * OC + n * OC, sizeof(T) * OC); + output_data_iter += OC; + } + } + } break; + } + } + + std::shared_ptr AcquireInputMemoryWithReorder( + const LoDTensor* input, const bool is_reverse) { + const auto name = this->key_ + "@input_mem"; + auto memory_p = + std::static_pointer_cast(this->dev_ctx_.GetBlob(name)); + + if (!memory_p) { + memory_p = std::make_shared(this->fwd_pd_->src_desc(), + this->engine_); + this->dev_ctx_.SetBlob(name, memory_p); + } + + const auto& input_lod = input->lod()[0]; + auto* x_data = input->data(); + + auto* x_onednn_data = reinterpret_cast(memory_p->get_data_handle()); + memset(x_onednn_data, 0, sizeof(T) * N * Ti * IC); + + if (platform::GetMKLDNNFormat(this->fwd_pd_->src_desc()) == + dnnl::memory::format_tag::ntc) { + reorderRNNdata(x_data, x_onednn_data, input_lod, is_reverse, + platform::RNNReorderType::PP_NTC); + } else { + reorderRNNdata(x_data, x_onednn_data, input_lod, is_reverse, + platform::RNNReorderType::PP_TNC); + } + return memory_p; + } + + std::shared_ptr AcquireOutputMemory() { + const auto name = this->key_ + "@output_mem"; + auto memory_p = + std::static_pointer_cast(this->dev_ctx_.GetBlob(name)); + + if (!memory_p) { + memory_p = std::make_shared(this->fwd_pd_->dst_desc(), + this->engine_); + this->dev_ctx_.SetBlob(name, memory_p); + } + return memory_p; + } + + std::shared_ptr AcquireH0Memory(const Tensor* h0) { + const std::string h0_key = memory_key_ + "@h0"; + auto memory_p = + std::static_pointer_cast(this->dev_ctx_.GetBlob(h0_key)); + + auto* h0_data = to_void_cast(h0->data()); + + if (!memory_p) { + memory_p = std::make_shared( + this->fwd_pd_->weights_layer_desc(), this->engine_, h0_data); + this->dev_ctx_.SetBlob(h0_key, memory_p); + } else { + memory_p->set_data_handle(h0_data); + } + return memory_p; + } + + std::shared_ptr AcquireWeightXMemory(const Tensor* weight_x, + const bool origin_mode) { + const std::string wx_key = memory_key_ + "@weight_x"; + auto memory_p = + std::static_pointer_cast(this->dev_ctx_.GetBlob(wx_key)); + + if (!memory_p) { + auto user_md = + MKLDNNMemDesc({1, 1, IC, 3, OC}, MKLDNNGetDataType(), + MKLDNNMemoryFormat::ldigo); + auto user_memory = dnnl::memory(user_md, this->engine_); + + auto* weight_x_data = + reinterpret_cast(user_memory.get_data_handle()); + memcpy(weight_x_data, weight_x->data(), + sizeof(float) * IC * 3 * OC); + + if (origin_mode == false) { + for (int64_t i = 0; i < IC; ++i) { + for (int64_t j = 0; j < OC; ++j) { + weight_x_data[j] *= -1; + } + weight_x_data += 3 * OC; + } + } + + memory_p = std::make_shared( + this->fwd_pd_->weights_layer_desc(), this->engine_); + + dnnl::stream astream(this->engine_); + dnnl::reorder(user_memory, *memory_p) + .execute(astream, user_memory, *memory_p); + + this->dev_ctx_.SetBlob(wx_key, memory_p); + } + return memory_p; + } + + std::shared_ptr AcquireWeightHMemory(const Tensor* weight_h, + const bool origin_mode) { + const std::string wh_key = memory_key_ + "@weight_h"; + auto memory_p = + std::static_pointer_cast(this->dev_ctx_.GetBlob(wh_key)); + + if (!memory_p) { + auto user_md = + MKLDNNMemDesc({1, 1, OC, 3, OC}, MKLDNNGetDataType(), + MKLDNNMemoryFormat::ldigo); + auto user_memory = dnnl::memory(user_md, this->engine_); + + // Reorder weights_h from PP format [OC, 2OC] + [OC, OC] to + // oneDNN format [OC, 3OC] + auto* weight_h_data = + reinterpret_cast(user_memory.get_data_handle()); + auto* user_weight_h_data = weight_h->data(); + + auto src1_iter = user_weight_h_data; + auto src2_iter = user_weight_h_data + 2 * OC * OC; + + for (int64_t c = 0; c < OC; ++c) { + memcpy(weight_h_data, src1_iter, 2 * OC * sizeof(float)); + memcpy(weight_h_data + 2 * OC, src2_iter, OC * sizeof(float)); + + src1_iter += 2 * OC; + src2_iter += OC; + weight_h_data += 3 * OC; + } + + weight_h_data = reinterpret_cast(user_memory.get_data_handle()); + + if (origin_mode == false) { + for (int64_t i = 0; i < OC; ++i) { + for (int64_t j = 0; j < OC; ++j) { + weight_h_data[j] *= -1; + } + weight_h_data += 3 * OC; + } + } + + memory_p = std::make_shared( + this->fwd_pd_->weights_iter_desc(), this->engine_); + + dnnl::stream astream(this->engine_); + dnnl::reorder(user_memory, *memory_p) + .execute(astream, user_memory, *memory_p); + + this->dev_ctx_.SetBlob(wh_key, memory_p); + } + return memory_p; + } + + std::shared_ptr AcquireBiasMemory(const Tensor* bias, + const bool origin_mode) { + const std::string bias_key = memory_key_ + "@bias"; + auto memory_p = std::static_pointer_cast( + this->dev_ctx_.GetBlob(bias_key)); + + if (!memory_p) { + memory_p = std::make_shared(this->fwd_pd_->bias_desc(), + this->engine_); + auto* bias_data = reinterpret_cast(memory_p->get_data_handle()); + if (bias) { + const float* user_bias_data = + bias->data(); // Bias in oneDNN is always float + memcpy(bias_data, user_bias_data, sizeof(float) * 3 * OC); + } else { + // oneDNN always need bias memory, if it's not provided in PP, let + // oneDNN allocate memory and set it to 0 + memset(bias_data, 0, sizeof(float) * 3 * OC); + } + + if (origin_mode == false && bias) { + for (int64_t i = 0; i < OC; ++i) { + bias_data[i] *= -1; + } + } + this->dev_ctx_.SetBlob(bias_key, memory_p); + } + return memory_p; + } + + private: + // RNN dimensions + // N - Batch Size + // Ti - Max sentence length + // IC - Input Channels + // OC - Output Channels + const int64_t N, Ti, IC, OC; + + // Memory size of weights, bias and h0 does not depend + // on Ti size, thus we need another key to cache them + std::string memory_key_; +}; + +template +class FusionGRUMKLDNNKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto& dev_ctx = + ctx.template device_context(); + const auto& mkldnn_engine = dev_ctx.GetEngine(); + + // Get Tensors + const auto* input = ctx.Input("X"); + const auto* h0 = ctx.Input("H0"); + const auto* weight_x = ctx.Input("WeightX"); + const auto* weight_h = ctx.Input("WeightH"); + const auto* bias = ctx.Input("Bias"); + auto* hidden = ctx.Output("Hidden"); + + // Get attributes + const bool is_reverse = ctx.Attr("is_reverse"); + const bool origin_mode = ctx.Attr("origin_mode"); + + // Get tensor dimensions + const auto x_dims = framework::vectorize(input->dims()); + const auto weight_h_dims = framework::vectorize(weight_h->dims()); + const auto& input_lod = input->lod()[0]; + + // Calculate RNN dimensions + const int64_t N = input_lod.size() - 1; // Number of sentences (batches) + const int64_t Ti = // Max length of the sentence in a batch + [&input_lod]() { + size_t res = 0; + for (size_t i = 0; i < (input_lod.size() - 1); ++i) { + res = std::max(res, input_lod[i + 1] - input_lod[i]); + } + return res; + }(); + const int64_t IC = x_dims[1]; // Input channels + const int64_t OC = weight_h_dims[0]; // Output channels + + GRUMKLDNNHandler handler(ctx, dev_ctx, mkldnn_engine, ctx.GetPlace(), + input, weight_h, h0, is_reverse, N, Ti, IC, OC, + ctx.InputName("X") + ctx.InputName("WeightH")); + + auto input_memory_p = + handler.AcquireInputMemoryWithReorder(input, is_reverse); + auto weight_x_memory_p = + handler.AcquireWeightXMemory(weight_x, origin_mode); + auto weight_h_memory_p = + handler.AcquireWeightHMemory(weight_h, origin_mode); + auto bias_memory_p = handler.AcquireBiasMemory(bias, origin_mode); + auto hidden_onednn_memory_p = handler.AcquireOutputMemory(); + + std::unordered_map gru_args = { + {DNNL_ARG_SRC_LAYER, *input_memory_p}, + {DNNL_ARG_WEIGHTS_LAYER, *weight_x_memory_p}, + {DNNL_ARG_WEIGHTS_ITER, *weight_h_memory_p}, + {DNNL_ARG_BIAS, *bias_memory_p}, + {DNNL_ARG_DST_LAYER, *hidden_onednn_memory_p}}; + + if (h0) { + auto h0_memory_p = handler.AcquireH0Memory(h0); + gru_args.insert({DNNL_ARG_SRC_ITER, *h0_memory_p}); + } + + auto gru_forward_p = handler.AcquireForwardPrimitive(); + + dnnl::stream astream(mkldnn_engine); + gru_forward_p->execute(astream, gru_args); + astream.wait(); + + auto* hidden_onednn_data = + reinterpret_cast(hidden_onednn_memory_p->get_data_handle()); + auto* hidden_data = hidden->mutable_data(ctx.GetPlace()); + if (handler.is_NTC()) { + handler.reorderRNNdata(hidden_onednn_data, hidden_data, input_lod, + is_reverse, platform::RNNReorderType::NTC_PP); + } else { + handler.reorderRNNdata(hidden_onednn_data, hidden_data, input_lod, + is_reverse, platform::RNNReorderType::TNC_PP); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_KERNEL(fusion_gru, MKLDNN, paddle::platform::CPUPlace, + ops::FusionGRUMKLDNNKernel); diff --git a/paddle/fluid/platform/mkldnn_helper.h b/paddle/fluid/platform/mkldnn_helper.h index c147bdccbe..60588d89db 100644 --- a/paddle/fluid/platform/mkldnn_helper.h +++ b/paddle/fluid/platform/mkldnn_helper.h @@ -181,6 +181,8 @@ inline mkldnn::memory::format_tag GetMKLDNNFormat( if (inner_nblks == 0) { if (strides[0] >= strides[1] && strides[1] >= strides[2]) { return mkldnn::memory::format_tag::ncw; + } else if (strides[1] >= strides[0] && strides[0] >= strides[2]) { + return mkldnn::memory::format_tag::ntc; } else { return mkldnn::memory::format_tag::nwc; } @@ -420,5 +422,7 @@ inline std::vector> ToMkldnnPadding( } } +enum class RNNReorderType { PP_NTC, PP_TNC, NTC_PP, TNC_PP }; + } // namespace platform } // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_fusion_gru_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_fusion_gru_mkldnn_op.py new file mode 100644 index 0000000000..cfbbf7de22 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_fusion_gru_mkldnn_op.py @@ -0,0 +1,78 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +from paddle.fluid.tests.unittests.test_fusion_gru_op import TestFusionGRUOp + + +class TestFusionGRUMKLDNNOp(TestFusionGRUOp): + def set_confs(self): + self.use_mkldnn = True + + +class TestFusionGRUMKLDNNOpNoInitial(TestFusionGRUOp): + def set_confs(self): + self.with_h0 = False + self.use_mkldnn = True + + +class TestFusionGRUMKLDNNOpNoBias(TestFusionGRUOp): + def set_confs(self): + self.with_bias = False + self.use_mkldnn = True + + +class TestFusionGRUMKLDNNOpReverse(TestFusionGRUOp): + def set_confs(self): + self.is_reverse = True + self.use_mkldnn = True + + +class TestFusionGRUMKLDNNOpOriginMode(TestFusionGRUOp): + def set_confs(self): + self.origin_mode = True + self.use_mkldnn = True + + +class TestFusionGRUMKLDNNOpMD1(TestFusionGRUOp): + def set_confs(self): + self.M = 36 + self.D = 8 + self.use_mkldnn = True + + +class TestFusionGRUMKLDNNOpMD2(TestFusionGRUOp): + def set_confs(self): + self.M = 8 + self.D = 8 + self.use_mkldnn = True + + +class TestFusionGRUMKLDNNOpMD3(TestFusionGRUOp): + def set_confs(self): + self.M = 17 + self.D = 15 + self.use_mkldnn = True + + +class TestFusionGRUMKLDNNOpBS1(TestFusionGRUOp): + def set_confs(self): + self.lod = [[3]] + self.D = 16 + self.use_mkldnn = True + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_fusion_gru_op.py b/python/paddle/fluid/tests/unittests/test_fusion_gru_op.py index fb74545425..d8a5816a42 100644 --- a/python/paddle/fluid/tests/unittests/test_fusion_gru_op.py +++ b/python/paddle/fluid/tests/unittests/test_fusion_gru_op.py @@ -30,6 +30,7 @@ def fusion_gru( wh, # D x 3D bias, # 1 x 3D is_reverse, + origin_mode, act_state, act_gate): return gru(fc(x, wx, bias), @@ -40,7 +41,8 @@ def fusion_gru( (1, wh.shape[1]), dtype='float32'), is_reverse, act_state, - act_gate) + act_gate, + origin_mode=origin_mode) class TestFusionGRUOp(OpTest): @@ -57,6 +59,8 @@ class TestFusionGRUOp(OpTest): self.with_bias = True self.act_state = 'tanh' self.act_gate = 'sigmoid' + self.origin_mode = False + self.use_mkldnn = False self.set_confs() T = sum(self.lod[0]) @@ -73,7 +77,7 @@ class TestFusionGRUOp(OpTest): (N, self.D), dtype='float32') _, _, _, hidden = fusion_gru( - x, self.lod, h0, wx, wh, bias, self.is_reverse, + x, self.lod, h0, wx, wh, bias, self.is_reverse, self.origin_mode, ACTIVATION[self.act_state], ACTIVATION[self.act_gate]) self.inputs = {'X': (x, self.lod), 'WeightX': wx, 'WeightH': wh} @@ -89,7 +93,9 @@ class TestFusionGRUOp(OpTest): self.attrs = { 'activation': self.act_state, 'gate_activation': self.act_gate, - 'is_reverse': self.is_reverse + 'is_reverse': self.is_reverse, + 'origin_mode': self.origin_mode, + 'use_mkldnn': self.use_mkldnn } def test_check_output(self):