Add oneDNN fusion_gru kernel (#25594)

* Add oneDNN fusion_gru kernel and fix fc+gru pass
test=develop

* Formatting changes
test=develop

* Lint fixes
test=develop

* Add memory::format_tag::any to GRU weights
test=develop

* Fix build with CUDA

* Fix build with CUDA v2
Adam authored 5 years ago, committed by GitHub
parent 0cb60c700d
commit 68c6160e63

@@ -118,7 +118,7 @@ function(op_library TARGET)
     "tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op"
     "fusion_transpose_flatten_concat_op" "fusion_conv_inception_op"
     "sync_batch_norm_op" "dgc_op" "fused_fc_elementwise_layernorm_op"
-    "multihead_matmul_op" "fusion_group_op" "fused_bn_activation_op" "fused_embedding_eltwise_layernorm_op")
+    "multihead_matmul_op" "fusion_group_op" "fused_bn_activation_op" "fused_embedding_eltwise_layernorm_op" "fusion_gru_op")
   if ("${TARGET}" STREQUAL "${manual_pybind_op}")
     set(pybind_flag 1)
   endif()

@@ -7,7 +7,12 @@ register_operators(EXCLUDES
     fused_fc_elementwise_layernorm_op
     multihead_matmul_op
     fused_embedding_eltwise_layernorm_op
-    fusion_group_op)
+    fusion_group_op
+    fusion_gru_op)
 
+# fusion_gru_op does not have CUDA kernel
+op_library(fusion_gru_op)
+file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(fusion_gru);\n")
+
 if (WITH_GPU)
     # fused_bn_activation_op needs cudnn 7.4.1 above

@@ -19,6 +19,9 @@ limitations under the License. */
 #include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/operators/math/fc.h"
 #include "paddle/fluid/operators/math/sequence2batch.h"
+#ifdef PADDLE_WITH_MKLDNN
+#include "paddle/fluid/platform/mkldnn_helper.h"
+#endif
 
 namespace paddle {
 namespace operators {
@@ -122,8 +125,17 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
 
 framework::OpKernelType FusionGRUOp::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
+  framework::LibraryType library = framework::LibraryType::kPlain;
+  framework::DataLayout layout = framework::DataLayout::kAnyLayout;
+#ifdef PADDLE_WITH_MKLDNN
+  if (platform::CanMKLDNNBeUsed(ctx)) {
+    library = framework::LibraryType::kMKLDNN;
+    layout = framework::DataLayout::kMKLDNN;
+  }
+#endif
   return framework::OpKernelType(
-      OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.device_context());
+      OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), layout,
+      library);
 }
 
 void FusionGRUOpMaker::Make() {
@@ -187,6 +199,9 @@ void FusionGRUOpMaker::Make() {
               "bool"
               "use origin mode in article https://arxiv.org/abs/1412.3555")
       .SetDefault(false);
+  AddAttr<bool>("use_mkldnn",
+                "(bool, default false) Only used in mkldnn kernel")
+      .SetDefault(false);
   AddComment(R"DOC(
 The Fusion complete GRU Operator.
 This operator fuse the fully-connected operator into GRU,

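The use_mkldnn attribute added above is what GetExpectedKernelType inspects when choosing the MKLDNN kernel. As a rough usage sketch (not part of this patch; the model path is a placeholder and the exact pass behaviour is an assumption), CPU inference code of this era could enable the whole path through the analysis config, so that the fused GRU produced by the fc+gru pass runs on the new oneDNN kernel:

# Hedged sketch: enable oneDNN for CPU inference so fused GRU ops execute
# on the new kernel. "./gru_model" is a placeholder model directory.
from paddle.fluid.core import AnalysisConfig, create_paddle_predictor

config = AnalysisConfig("./gru_model")
config.disable_gpu()           # CPU only; fusion_gru has no CUDA kernel
config.switch_ir_optim(True)   # keep IR passes, including the fc+gru fusion, enabled
config.enable_mkldnn()         # lets supported ops run with use_mkldnn=True
predictor = create_paddle_predictor(config)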
File diff suppressed because it is too large.

@@ -181,6 +181,8 @@ inline mkldnn::memory::format_tag GetMKLDNNFormat(
     if (inner_nblks == 0) {
       if (strides[0] >= strides[1] && strides[1] >= strides[2]) {
         return mkldnn::memory::format_tag::ncw;
+      } else if (strides[1] >= strides[0] && strides[0] >= strides[2]) {
+        return mkldnn::memory::format_tag::ntc;
       } else {
         return mkldnn::memory::format_tag::nwc;
       }
@@ -420,5 +422,7 @@ inline std::vector<std::vector<int64_t>> ToMkldnnPadding(
   }
 }
 
+enum class RNNReorderType { PP_NTC, PP_TNC, NTC_PP, TNC_PP };
+
 }  // namespace platform
 }  // namespace paddle

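For context on the new ntc branch in GetMKLDNNFormat: for a rank-3, non-blocked memory descriptor the tag is inferred purely from the relative ordering of the three strides (the dimension with the largest stride is outermost). A minimal illustration of that rule, with a made-up helper name and Python used only for brevity (not code from this patch):

def classify_rank3_format(strides):
    # Mirrors the stride comparisons above for the inner_nblks == 0 case.
    s0, s1, s2 = strides
    if s0 >= s1 >= s2:
        return "ncw"  # dim 0 outermost, dim 2 contiguous
    elif s1 >= s0 >= s2:
        return "ntc"  # dim 1 outermost -- the branch added here, typical for RNN inputs
    return "nwc"      # dim 1 contiguous

# Logical dims (T=5, N=2, C=4) stored batch-major give strides (4, 20, 1):
print(classify_rank3_format((4, 20, 1)))  # -> "ntc"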
@@ -0,0 +1,78 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import numpy as np
+from paddle.fluid.tests.unittests.test_fusion_gru_op import TestFusionGRUOp
+
+
+class TestFusionGRUMKLDNNOp(TestFusionGRUOp):
+    def set_confs(self):
+        self.use_mkldnn = True
+
+
+class TestFusionGRUMKLDNNOpNoInitial(TestFusionGRUOp):
+    def set_confs(self):
+        self.with_h0 = False
+        self.use_mkldnn = True
+
+
+class TestFusionGRUMKLDNNOpNoBias(TestFusionGRUOp):
+    def set_confs(self):
+        self.with_bias = False
+        self.use_mkldnn = True
+
+
+class TestFusionGRUMKLDNNOpReverse(TestFusionGRUOp):
+    def set_confs(self):
+        self.is_reverse = True
+        self.use_mkldnn = True
+
+
+class TestFusionGRUMKLDNNOpOriginMode(TestFusionGRUOp):
+    def set_confs(self):
+        self.origin_mode = True
+        self.use_mkldnn = True
+
+
+class TestFusionGRUMKLDNNOpMD1(TestFusionGRUOp):
+    def set_confs(self):
+        self.M = 36
+        self.D = 8
+        self.use_mkldnn = True
+
+
+class TestFusionGRUMKLDNNOpMD2(TestFusionGRUOp):
+    def set_confs(self):
+        self.M = 8
+        self.D = 8
+        self.use_mkldnn = True
+
+
+class TestFusionGRUMKLDNNOpMD3(TestFusionGRUOp):
+    def set_confs(self):
+        self.M = 17
+        self.D = 15
+        self.use_mkldnn = True
+
+
+class TestFusionGRUMKLDNNOpBS1(TestFusionGRUOp):
+    def set_confs(self):
+        self.lod = [[3]]
+        self.D = 16
+        self.use_mkldnn = True
+
+
+if __name__ == "__main__":
+    unittest.main()

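Each new test class only overrides set_confs(); input generation, the NumPy reference GRU and test_check_output are all inherited from TestFusionGRUOp, so the existing reference implementation is what validates the oneDNN kernel across these shapes and flags. A further variant (hypothetical, not in this patch) would follow the same pattern:

# Hypothetical extra case: reversed sequence without bias, still routed to
# the oneDNN kernel via use_mkldnn, reusing the inherited reference check.
class TestFusionGRUMKLDNNOpReverseNoBias(TestFusionGRUOp):
    def set_confs(self):
        self.is_reverse = True
        self.with_bias = False
        self.use_mkldnn = True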
@@ -30,6 +30,7 @@ def fusion_gru(
         wh,  # D x 3D
         bias,  # 1 x 3D
         is_reverse,
+        origin_mode,
         act_state,
         act_gate):
     return gru(fc(x, wx, bias),
@@ -40,7 +41,8 @@
                    (1, wh.shape[1]), dtype='float32'),
                is_reverse,
                act_state,
-               act_gate)
+               act_gate,
+               origin_mode=origin_mode)
 
 
 class TestFusionGRUOp(OpTest):
@@ -57,6 +59,8 @@ class TestFusionGRUOp(OpTest):
         self.with_bias = True
         self.act_state = 'tanh'
         self.act_gate = 'sigmoid'
+        self.origin_mode = False
+        self.use_mkldnn = False
         self.set_confs()
 
         T = sum(self.lod[0])
@@ -73,7 +77,7 @@
             (N, self.D), dtype='float32')
 
         _, _, _, hidden = fusion_gru(
-            x, self.lod, h0, wx, wh, bias, self.is_reverse,
+            x, self.lod, h0, wx, wh, bias, self.is_reverse, self.origin_mode,
             ACTIVATION[self.act_state], ACTIVATION[self.act_gate])
 
         self.inputs = {'X': (x, self.lod), 'WeightX': wx, 'WeightH': wh}
@@ -89,7 +93,9 @@
         self.attrs = {
             'activation': self.act_state,
             'gate_activation': self.act_gate,
-            'is_reverse': self.is_reverse
+            'is_reverse': self.is_reverse,
+            'origin_mode': self.origin_mode,
+            'use_mkldnn': self.use_mkldnn
         }
 
     def test_check_output(self):
